Browse Source

Fix typing issues and improve typing completeness across the library

Co-authored-by: Danny <[email protected]>
Co-authored-by: Josh <[email protected]>
pull/7681/head
Stocker 3 years ago
committed by GitHub
parent
commit
5aa696ccfa
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 21
      discord/__main__.py
  2. 56
      discord/abc.py
  3. 42
      discord/activity.py
  4. 22
      discord/app_commands/commands.py
  5. 8
      discord/app_commands/errors.py
  6. 87
      discord/app_commands/models.py
  7. 4
      discord/app_commands/transformers.py
  8. 53
      discord/app_commands/tree.py
  9. 53
      discord/asset.py
  10. 25
      discord/audit_logs.py
  11. 22
      discord/channel.py
  12. 10
      discord/client.py
  13. 13
      discord/colour.py
  14. 10
      discord/components.py
  15. 10
      discord/context_managers.py
  16. 12
      discord/embeds.py
  17. 4
      discord/emoji.py
  18. 44
      discord/enums.py
  19. 18
      discord/ext/commands/_types.py
  20. 88
      discord/ext/commands/bot.py
  21. 16
      discord/ext/commands/cog.py
  22. 15
      discord/ext/commands/context.py
  23. 86
      discord/ext/commands/converter.py
  24. 2
      discord/ext/commands/cooldowns.py
  25. 115
      discord/ext/commands/core.py
  26. 14
      discord/ext/commands/errors.py
  27. 18
      discord/ext/commands/flags.py
  28. 293
      discord/ext/commands/help.py
  29. 37
      discord/ext/commands/view.py
  30. 8
      discord/ext/tasks/__init__.py
  31. 2
      discord/file.py
  32. 20
      discord/flags.py
  33. 30
      discord/gateway.py
  34. 5
      discord/guild.py
  35. 22
      discord/http.py
  36. 17
      discord/integrations.py
  37. 7
      discord/interactions.py
  38. 4
      discord/invite.py
  39. 13
      discord/member.py
  40. 8
      discord/mentions.py
  41. 17
      discord/message.py
  42. 5
      discord/opus.py
  43. 15
      discord/partial_emoji.py
  44. 4
      discord/permissions.py
  45. 21
      discord/player.py
  46. 4
      discord/reaction.py
  47. 4
      discord/role.py
  48. 6
      discord/scheduled_event.py
  49. 5
      discord/shard.py
  50. 6
      discord/stage_instance.py
  51. 58
      discord/state.py
  52. 2
      discord/sticker.py
  53. 71
      discord/threads.py
  54. 1
      discord/types/activity.py
  55. 5
      discord/types/widget.py
  56. 17
      discord/ui/button.py
  57. 4
      discord/ui/item.py
  58. 8
      discord/ui/modal.py
  59. 18
      discord/ui/select.py
  60. 30
      discord/ui/view.py
  61. 6
      discord/user.py
  62. 36
      discord/utils.py
  63. 10
      discord/voice_client.py
  64. 122
      discord/webhook/async_.py
  65. 62
      discord/webhook/sync.py
  66. 2
      discord/widget.py

21
discord/__main__.py

@ -23,7 +23,8 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Dict, Optional
from typing import Optional, Tuple, Dict
import argparse
import sys
@ -35,7 +36,7 @@ import aiohttp
import platform
def show_version():
def show_version() -> None:
entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
@ -52,7 +53,7 @@ def show_version():
print('\n'.join(entries))
def core(parser, args):
def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.version:
show_version()
@ -185,7 +186,7 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False):
def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path:
if isinstance(name, Path):
return name
@ -223,7 +224,7 @@ def to_path(parser, name, *, replace_spaces=False):
return Path(name)
def newbot(parser, args):
def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
# as a note exist_ok for Path is a 3.5+ only feature
@ -265,7 +266,7 @@ def newbot(parser, args):
print('successfully made bot at', new_directory)
def newcog(parser, args):
def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
cog_dir = to_path(parser, args.directory)
try:
cog_dir.mkdir(exist_ok=True)
@ -299,7 +300,7 @@ def newcog(parser, args):
print('successfully made cog at', directory)
def add_newbot_args(subparser):
def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser.set_defaults(func=newbot)
@ -310,7 +311,7 @@ def add_newbot_args(subparser):
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
def add_newcog_args(subparser):
def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser.set_defaults(func=newcog)
@ -322,7 +323,7 @@ def add_newcog_args(subparser):
parser.add_argument('--full', help='add all special methods as well', action='store_true')
def parse_args():
def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser.set_defaults(func=core)
@ -333,7 +334,7 @@ def parse_args():
return parser, parser.parse_args()
def main():
def main() -> None:
parser, args = parse_args()
args.func(parser, args)

56
discord/abc.py

@ -91,6 +91,9 @@ if TYPE_CHECKING:
GuildChannel as GuildChannelPayload,
OverwriteType,
)
from .types.snowflake import (
SnowflakeList,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
@ -708,7 +711,14 @@ class GuildChannel:
) -> None:
...
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Any = _undefined,
reason: Optional[str] = None,
**permissions: bool,
) -> None:
r"""|coro|
Sets the channel specific permission overwrites for a target in the
@ -917,7 +927,7 @@ class GuildChannel:
) -> None:
...
async def move(self, **kwargs) -> None:
async def move(self, **kwargs: Any) -> None:
"""|coro|
A rich interface to help move a channel relative to other channels.
@ -1248,22 +1258,22 @@ class Messageable:
async def send(
self,
content=None,
content: Optional[str] = None,
*,
tts=False,
embed=None,
embeds=None,
file=None,
files=None,
stickers=None,
delete_after=None,
nonce=None,
allowed_mentions=None,
reference=None,
mention_author=None,
view=None,
suppress_embeds=False,
):
tts: bool = False,
embed: Optional[Embed] = None,
embeds: Optional[List[Embed]] = None,
file: Optional[File] = None,
files: Optional[List[File]] = None,
stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None,
delete_after: Optional[float] = None,
nonce: Optional[Union[str, int]] = None,
allowed_mentions: Optional[AllowedMentions] = None,
reference: Optional[Union[Message, MessageReference, PartialMessage]] = None,
mention_author: Optional[bool] = None,
view: Optional[View] = None,
suppress_embeds: bool = False,
) -> Message:
"""|coro|
Sends a message to the destination with the content given.
@ -1368,17 +1378,17 @@ class Messageable:
previous_allowed_mention = state.allowed_mentions
if stickers is not None:
stickers = [sticker.id for sticker in stickers]
sticker_ids: SnowflakeList = [sticker.id for sticker in stickers]
else:
stickers = MISSING
sticker_ids = MISSING
if reference is not None:
try:
reference = reference.to_message_reference_dict()
reference_dict = reference.to_message_reference_dict()
except AttributeError:
raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None
else:
reference = MISSING
reference_dict = MISSING
if view and not hasattr(view, '__discord_ui_view__'):
raise TypeError(f'view parameter must be View not {view.__class__!r}')
@ -1399,10 +1409,10 @@ class Messageable:
embeds=embeds if embeds is not None else MISSING,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
message_reference=reference_dict,
previous_allowed_mentions=previous_allowed_mention,
mention_author=mention_author,
stickers=stickers,
stickers=sticker_ids,
view=view,
flags=flags,
) as params:

42
discord/activity.py

@ -123,7 +123,7 @@ class BaseActivity:
__slots__ = ('_created_at',)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
self._created_at: Optional[float] = kwargs.pop('created_at', None)
@property
@ -218,7 +218,7 @@ class Activity(BaseActivity):
'buttons',
)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None)
@ -363,7 +363,7 @@ class Game(BaseActivity):
__slots__ = ('name', '_end', '_start')
def __init__(self, name: str, **extra):
def __init__(self, name: str, **extra: Any) -> None:
super().__init__(**extra)
self.name: str = name
@ -420,10 +420,10 @@ class Game(BaseActivity):
}
# fmt: on
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Game) and other.name == self.name
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -477,7 +477,7 @@ class Streaming(BaseActivity):
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None:
super().__init__(**extra)
self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name)
@ -501,7 +501,7 @@ class Streaming(BaseActivity):
return f'<Streaming name={self.name!r}>'
@property
def twitch_name(self):
def twitch_name(self) -> Optional[str]:
"""Optional[:class:`str`]: If provided, the twitch name of the user streaming.
This corresponds to the ``large_image`` key of the :attr:`Streaming.assets`
@ -528,10 +528,10 @@ class Streaming(BaseActivity):
ret['details'] = self.details
return ret
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -563,14 +563,14 @@ class Spotify:
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
self._state: str = data.pop('state', '')
self._details: str = data.pop('details', '')
self._timestamps: Dict[str, int] = data.pop('timestamps', {})
self._timestamps: ActivityTimestamps = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {})
self._sync_id: str = data.pop('sync_id')
self._session_id: str = data.pop('session_id')
self._sync_id: str = data.pop('sync_id', '')
self._session_id: Optional[str] = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None)
@property
@ -622,7 +622,7 @@ class Spotify:
""":class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return (
isinstance(other, Spotify)
and other._session_id == self._session_id
@ -630,7 +630,7 @@ class Spotify:
and other.start == self.start
)
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -691,12 +691,14 @@ class Spotify:
@property
def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc)
# the start key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property
def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc)
# the end key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property
def duration(self) -> datetime.timedelta:
@ -742,7 +744,7 @@ class CustomActivity(BaseActivity):
__slots__ = ('name', 'emoji', 'state')
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any) -> None:
super().__init__(**extra)
self.name: Optional[str] = name
self.state: Optional[str] = extra.pop('state', None)
@ -786,10 +788,10 @@ class CustomActivity(BaseActivity):
o['emoji'] = self.emoji.to_dict()
return o
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:

22
discord/app_commands/commands.py

@ -166,7 +166,7 @@ def _validate_auto_complete_callback(
return callback
def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType:
def _context_menu_annotation(annotation: Any, *, _none: type = NoneType) -> AppCommandType:
if annotation is Message:
return AppCommandType.message
@ -686,7 +686,7 @@ class Group:
The parent group. ``None`` if there isn't one.
"""
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = []
__discord_app_commands_group_children__: ClassVar[List[Union[Command[Any, ..., Any], Group]]] = []
__discord_app_commands_skip_init_binding__: bool = False
__discord_app_commands_group_name__: str = MISSING
__discord_app_commands_group_description__: str = MISSING
@ -694,10 +694,12 @@ class Group:
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
if not cls.__discord_app_commands_group_children__:
cls.__discord_app_commands_group_children__ = children = [
children: List[Union[Command[Any, ..., Any], Group]] = [
member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None
]
cls.__discord_app_commands_group_children__ = children
found = set()
for child in children:
if child.name in found:
@ -796,15 +798,15 @@ class Group:
"""Optional[:class:`Group`]: The parent of this group."""
return self.parent
def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]:
def _get_internal_command(self, name: str) -> Optional[Union[Command[Any, ..., Any], Group]]:
return self._children.get(name)
@property
def commands(self) -> List[Union[Command, Group]]:
def commands(self) -> List[Union[Command[Any, ..., Any], Group]]:
"""List[Union[:class:`Command`, :class:`Group`]]: The commands that this group contains."""
return list(self._children.values())
async def on_error(self, interaction: Interaction, command: Command, error: AppCommandError) -> None:
async def on_error(self, interaction: Interaction, command: Command[Any, ..., Any], error: AppCommandError) -> None:
"""|coro|
A callback that is called when a child's command raises an :exc:`AppCommandError`.
@ -823,7 +825,7 @@ class Group:
pass
def add_command(self, command: Union[Command, Group], /, *, override: bool = False):
def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None:
"""Adds a command or group to this group's internal list of commands.
Parameters
@ -855,7 +857,7 @@ class Group:
if len(self._children) > 25:
raise ValueError('maximum number of child commands exceeded')
def remove_command(self, name: str, /) -> Optional[Union[Command, Group]]:
def remove_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]:
"""Removes a command or group from the internal list of commands.
Parameters
@ -872,7 +874,7 @@ class Group:
self._children.pop(name, None)
def get_command(self, name: str, /) -> Optional[Union[Command, Group]]:
def get_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]:
"""Retrieves a command or group from its name.
Parameters
@ -1046,7 +1048,7 @@ def describe(**parameters: str) -> Callable[[T], T]:
return decorator
def choices(**parameters: List[Choice]) -> Callable[[T], T]:
def choices(**parameters: List[Choice[ChoiceT]]) -> Callable[[T], T]:
r"""Instructs the given parameters by their name to use the given choices for their choices.
Example:

8
discord/app_commands/errors.py

@ -79,9 +79,9 @@ class CommandInvokeError(AppCommandError):
The command that failed.
"""
def __init__(self, command: Union[Command, ContextMenu], e: Exception) -> None:
def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu], e: Exception) -> None:
self.original: Exception = e
self.command: Union[Command, ContextMenu] = command
self.command: Union[Command[Any, ..., Any], ContextMenu] = command
super().__init__(f'Command {command.name!r} raised an exception: {e.__class__.__name__}: {e}')
@ -191,8 +191,8 @@ class CommandSignatureMismatch(AppCommandError):
The command that had the signature mismatch.
"""
def __init__(self, command: Union[Command, ContextMenu, Group]):
self.command: Union[Command, ContextMenu, Group] = command
def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu, Group]):
self.command: Union[Command[Any, ..., Any], ContextMenu, Group] = command
msg = (
f'The signature for command {command.name!r} is different from the one provided by Discord. '
'This can happen because either your code is out of date or you have not synced the '

87
discord/app_commands/models.py

@ -58,7 +58,10 @@ if TYPE_CHECKING:
PartialChannel,
PartialThread,
)
from ..types.threads import ThreadMetadata
from ..types.threads import (
ThreadMetadata,
ThreadArchiveDuration,
)
from ..state import ConnectionState
from ..guild import GuildChannel, Guild
from ..channel import TextChannel
@ -117,17 +120,19 @@ class AppCommand(Hashable):
'_state',
)
def __init__(self, *, data: ApplicationCommandPayload, state=None):
self._state = state
def __init__(self, *, data: ApplicationCommandPayload, state: Optional[ConnectionState] = None) -> None:
self._state: Optional[ConnectionState] = state
self._from_data(data)
def _from_data(self, data: ApplicationCommandPayload):
def _from_data(self, data: ApplicationCommandPayload) -> None:
self.id: int = int(data['id'])
self.application_id: int = int(data['application_id'])
self.name: str = data['name']
self.description: str = data['description']
self.type: AppCommandType = try_enum(AppCommandType, data.get('type', 1))
self.options = [app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])]
self.options: List[Union[Argument, AppCommandGroup]] = [
app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])
]
def to_dict(self) -> ApplicationCommandPayload:
return {
@ -262,12 +267,12 @@ class AppCommandChannel(Hashable):
data: PartialChannel,
guild_id: int,
):
self._state = state
self.guild_id = guild_id
self.id = int(data['id'])
self.type = try_enum(ChannelType, data['type'])
self.name = data['name']
self.permissions = Permissions(int(data['permissions']))
self._state: ConnectionState = state
self.guild_id: int = guild_id
self.id: int = int(data['id'])
self.type: ChannelType = try_enum(ChannelType, data['type'])
self.name: str = data['name']
self.permissions: Permissions = Permissions(int(data['permissions']))
def __str__(self) -> str:
return self.name
@ -405,13 +410,13 @@ class AppCommandThread(Hashable):
data: PartialThread,
guild_id: int,
):
self._state = state
self.guild_id = guild_id
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.type = try_enum(ChannelType, data['type'])
self.name = data['name']
self.permissions = Permissions(int(data['permissions']))
self._state: ConnectionState = state
self.guild_id: int = guild_id
self.id: int = int(data['id'])
self.parent_id: int = int(data['parent_id'])
self.type: ChannelType = try_enum(ChannelType, data['type'])
self.name: str = data['name']
self.permissions: Permissions = Permissions(int(data['permissions']))
self._unroll_metadata(data['thread_metadata'])
def __str__(self) -> str:
@ -425,14 +430,14 @@ class AppCommandThread(Hashable):
"""Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found."""
return self._state._get_guild(self.guild_id)
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
self._created_at = parse_time(data.get('create_timestamp'))
def _unroll_metadata(self, data: ThreadMetadata) -> None:
self.archived: bool = data['archived']
self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration: ThreadArchiveDuration = data['auto_archive_duration']
self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
self.locked: bool = data.get('locked', False)
self.invitable: bool = data.get('invitable', True)
self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
@property
def parent(self) -> Optional[TextChannel]:
@ -522,20 +527,24 @@ class Argument:
'_state',
)
def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None):
self._state = state
self.parent = parent
def __init__(
self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None
) -> None:
self._state: Optional[ConnectionState] = state
self.parent: ApplicationCommandParent = parent
self._from_data(data)
def __repr__(self) -> str:
return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>'
def _from_data(self, data: ApplicationCommandOption):
def _from_data(self, data: ApplicationCommandOption) -> None:
self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type'])
self.name: str = data['name']
self.description: str = data['description']
self.required: bool = data.get('required', False)
self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])]
self.choices: List[Choice[Union[int, float, str]]] = [
Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])
]
def to_dict(self) -> ApplicationCommandOption:
return {
@ -582,20 +591,24 @@ class AppCommandGroup:
'_state',
)
def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None):
self.parent = parent
self._state = state
def __init__(
self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None
) -> None:
self.parent: ApplicationCommandParent = parent
self._state: Optional[ConnectionState] = state
self._from_data(data)
def __repr__(self) -> str:
return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>'
def _from_data(self, data: ApplicationCommandOption):
def _from_data(self, data: ApplicationCommandOption) -> None:
self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type'])
self.name: str = data['name']
self.description: str = data['description']
self.required: bool = data.get('required', False)
self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])]
self.choices: List[Choice[Union[int, float, str]]] = [
Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])
]
self.arguments: List[Argument] = [
Argument(parent=self, state=self._state, data=d)
for d in data.get('options', [])
@ -614,7 +627,7 @@ class AppCommandGroup:
def app_command_option_factory(
parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state=None
parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state: Optional[ConnectionState] = None
) -> Union[Argument, AppCommandGroup]:
if is_app_command_argument_type(data['type']):
return Argument(parent=parent, data=data, state=state)

4
discord/app_commands/transformers.py

@ -95,7 +95,7 @@ class CommandParameter:
description: str = MISSING
required: bool = MISSING
default: Any = MISSING
choices: List[Choice] = MISSING
choices: List[Choice[Union[str, int, float]]] = MISSING
type: AppCommandOptionType = MISSING
channel_types: List[ChannelType] = MISSING
min_value: Optional[Union[int, float]] = None
@ -549,7 +549,7 @@ ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
def get_supported_annotation(
annotation: Any,
*,
_none=NoneType,
_none: type = NoneType,
_mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS,
) -> Tuple[Any, Any]:
"""Returns an appropriate, yet supported, annotation along with an optional default value.

53
discord/app_commands/tree.py

@ -26,7 +26,22 @@ from __future__ import annotations
import inspect
import sys
import traceback
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Set, Tuple, TypeVar, Union, overload
from typing import (
Any,
TYPE_CHECKING,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from collections import Counter
@ -194,13 +209,13 @@ class CommandTree(Generic[ClientT]):
def add_command(
self,
command: Union[Command, ContextMenu, Group],
command: Union[Command[Any, ..., Any], ContextMenu, Group],
/,
*,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
override: bool = False,
):
) -> None:
"""Adds an application command to the tree.
This only adds the command locally -- in order to sync the commands
@ -317,7 +332,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ...,
) -> Optional[Union[Command, Group]]:
) -> Optional[Union[Command[Any, ..., Any], Group]]:
...
@overload
@ -328,7 +343,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = ...,
type: AppCommandType = ...,
) -> Optional[Union[Command, ContextMenu, Group]]:
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
...
def remove_command(
@ -338,7 +353,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input,
) -> Optional[Union[Command, ContextMenu, Group]]:
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
"""Removes an application command from the tree.
This only removes the command locally -- in order to sync the commands
@ -396,7 +411,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ...,
) -> Optional[Union[Command, Group]]:
) -> Optional[Union[Command[Any, ..., Any], Group]]:
...
@overload
@ -407,7 +422,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = ...,
type: AppCommandType = ...,
) -> Optional[Union[Command, ContextMenu, Group]]:
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
...
def get_command(
@ -417,7 +432,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input,
) -> Optional[Union[Command, ContextMenu, Group]]:
) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]:
"""Gets a application command from the tree.
Parameters
@ -468,7 +483,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = ...,
type: Literal[AppCommandType.chat_input] = ...,
) -> List[Union[Command, Group]]:
) -> List[Union[Command[Any, ..., Any], Group]]:
...
@overload
@ -477,7 +492,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = ...,
type: AppCommandType = ...,
) -> Union[List[Union[Command, Group]], List[ContextMenu]]:
) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]:
...
def get_commands(
@ -485,7 +500,7 @@ class CommandTree(Generic[ClientT]):
*,
guild: Optional[Snowflake] = None,
type: AppCommandType = AppCommandType.chat_input,
) -> Union[List[Union[Command, Group]], List[ContextMenu]]:
) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]:
"""Gets all application commands from the tree.
Parameters
@ -518,9 +533,11 @@ class CommandTree(Generic[ClientT]):
value = type.value
return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value]
def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]:
def _get_all_commands(
self, *, guild: Optional[Snowflake] = None
) -> List[Union[Command[Any, ..., Any], Group, ContextMenu]]:
if guild is None:
base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values())
base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(self._global_commands.values())
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None)
return base
else:
@ -530,7 +547,7 @@ class CommandTree(Generic[ClientT]):
guild_id = guild.id
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id]
else:
base: List[Union[Command, Group, ContextMenu]] = list(commands.values())
base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values())
guild_id = guild.id
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
return base
@ -564,7 +581,7 @@ class CommandTree(Generic[ClientT]):
async def on_error(
self,
interaction: Interaction,
command: Optional[Union[ContextMenu, Command]],
command: Optional[Union[ContextMenu, Command[Any, ..., Any]]],
error: AppCommandError,
) -> None:
"""|coro|
@ -742,7 +759,7 @@ class CommandTree(Generic[ClientT]):
self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int):
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int) -> None:
name = data['name']
guild_id = _get_as_snowflake(data, 'guild_id')
ctx_menu = self._context_menus.get((name, guild_id, type))
@ -770,7 +787,7 @@ class CommandTree(Generic[ClientT]):
except AppCommandError as e:
await self.on_error(interaction, ctx_menu, e)
async def call(self, interaction: Interaction):
async def call(self, interaction: Interaction) -> None:
"""|coro|
Given an :class:`~discord.Interaction`, calls the matching

53
discord/asset.py

@ -39,6 +39,13 @@ __all__ = (
# fmt: on
if TYPE_CHECKING:
from typing_extensions import Self
from .state import ConnectionState
from .webhook.async_ import _WebhookState
_State = Union[ConnectionState, _WebhookState]
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
@ -77,7 +84,7 @@ class AssetMixin:
return await self._state.http.get_from_cdn(self.url)
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int:
async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int:
"""|coro|
Saves this asset into a file-like object.
@ -153,14 +160,14 @@ class Asset(AssetMixin):
BASE = 'https://cdn.discordapp.com'
def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
self._url = url
self._animated = animated
self._key = key
def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None:
self._state: _State = state
self._url: str = url
self._animated: bool = animated
self._key: str = key
@classmethod
def _from_default_avatar(cls, state, index: int) -> Asset:
def _from_default_avatar(cls, state: _State, index: int) -> Self:
return cls(
state,
url=f'{cls.BASE}/embed/avatars/{index}.png',
@ -169,7 +176,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -180,7 +187,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -191,7 +198,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
@ -200,7 +207,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
@ -209,7 +216,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_scheduled_event_cover_image(cls, state, scheduled_event_id: int, cover_image_hash: str) -> Asset:
def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024',
@ -218,7 +225,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self:
animated = image.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -229,7 +236,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self:
animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -240,7 +247,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset:
def _from_sticker_banner(cls, state: _State, banner: int) -> Self:
return cls(
state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
@ -249,7 +256,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self:
animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -265,14 +272,14 @@ class Asset(AssetMixin):
def __len__(self) -> int:
return len(self._url)
def __repr__(self):
def __repr__(self) -> str:
shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>'
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, Asset) and self._url == other._url
def __hash__(self):
def __hash__(self) -> int:
return hash(self._url)
@property
@ -295,7 +302,7 @@ class Asset(AssetMixin):
size: int = MISSING,
format: ValidAssetFormatTypes = MISSING,
static_format: ValidStaticFormatTypes = MISSING,
) -> Asset:
) -> Self:
"""Returns a new asset with the passed components replaced.
.. versionchanged:: 2.0
@ -350,7 +357,7 @@ class Asset(AssetMixin):
url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Asset:
def with_size(self, size: int, /) -> Self:
"""Returns a new asset with the specified size.
.. versionchanged:: 2.0
@ -378,7 +385,7 @@ class Asset(AssetMixin):
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset:
def with_format(self, format: ValidAssetFormatTypes, /) -> Self:
"""Returns a new asset with the specified format.
.. versionchanged:: 2.0
@ -413,7 +420,7 @@ class Asset(AssetMixin):
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self:
"""Returns a new asset with the specified static format.
This only changes the format if the underlying asset is

25
discord/audit_logs.py

@ -50,12 +50,12 @@ if TYPE_CHECKING:
from .member import Member
from .role import Role
from .scheduled_event import ScheduledEvent
from .state import ConnectionState
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload,
)
from .types.channel import (
PartialChannel as PartialChannelPayload,
PermissionOverwrite as PermissionOverwritePayload,
)
from .types.invite import Invite as InvitePayload
@ -242,8 +242,8 @@ class AuditLogChanges:
# fmt: on
def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]):
self.before = AuditLogDiff()
self.after = AuditLogDiff()
self.before: AuditLogDiff = AuditLogDiff()
self.after: AuditLogDiff = AuditLogDiff()
for elem in data:
attr = elem['key']
@ -390,17 +390,17 @@ class AuditLogEntry(Hashable):
"""
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
self._state = guild._state
self.guild = guild
self._users = users
self._state: ConnectionState = guild._state
self.guild: Guild = guild
self._users: Dict[int, User] = users
self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id'])
self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id: int = int(data['id'])
# this key is technically not usually present
self.reason = data.get('reason')
self.reason: Optional[str] = data.get('reason')
extra = data.get('options')
# fmt: off
@ -464,10 +464,13 @@ class AuditLogEntry(Hashable):
self._changes = data.get('changes', [])
user_id = utils._get_as_snowflake(data, 'user_id')
self.user = user_id and self._get_member(user_id)
self.user: Optional[Union[User, Member]] = self._get_member(user_id)
self._target_id = utils._get_as_snowflake(data, 'target_id')
def _get_member(self, user_id: int) -> Union[Member, User, None]:
def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]:
if user_id is None:
return None
return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str:

22
discord/channel.py

@ -198,7 +198,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data)
async def _get_channel(self):
async def _get_channel(self) -> Self:
return self
@property
@ -283,7 +283,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[TextChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]:
"""|coro|
Edits the channel.
@ -908,7 +908,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild
self.guild: Guild = guild
self.name: str = data['name']
self.rtc_region: Optional[str] = data.get('rtc_region')
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
@ -1076,7 +1076,7 @@ class VoiceChannel(VocalGuildChannel):
async def edit(self) -> Optional[VoiceChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]:
"""|coro|
Edits the channel.
@ -1220,7 +1220,7 @@ class StageChannel(VocalGuildChannel):
def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data)
self.topic = data.get('topic')
self.topic: Optional[str] = data.get('topic')
@property
def requesting_to_speak(self) -> List[Member]:
@ -1361,7 +1361,7 @@ class StageChannel(VocalGuildChannel):
async def edit(self) -> Optional[StageChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]:
"""|coro|
Edits the channel.
@ -1522,7 +1522,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[CategoryChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]:
"""|coro|
Edits the channel.
@ -1578,7 +1578,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs):
async def move(self, **kwargs: Any) -> None:
kwargs.pop('category', None)
await super().move(**kwargs)
@ -1772,7 +1772,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[StoreChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]:
"""|coro|
Edits the channel.
@ -1874,7 +1874,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
self.me: ClientUser = me
self.id: int = int(data['id'])
async def _get_channel(self):
async def _get_channel(self) -> Self:
return self
def __str__(self) -> str:
@ -2026,7 +2026,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
else:
self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients)
async def _get_channel(self):
async def _get_channel(self) -> Self:
return self
def __str__(self) -> str:

10
discord/client.py

@ -196,11 +196,11 @@ class Client:
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock)
self._handlers: Dict[str, Callable] = {
self._handlers: Dict[str, Callable[..., None]] = {
'ready': self._handle_ready,
}
self._hooks: Dict[str, Callable] = {
self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
'before_identify': self._call_before_identify_hook,
}
@ -698,7 +698,7 @@ class Client:
raise TypeError('activity must derive from BaseActivity.')
@property
def status(self):
def status(self) -> Status:
""":class:`.Status`:
The status being used upon logging on to Discord.
@ -709,7 +709,7 @@ class Client:
return Status.online
@status.setter
def status(self, value):
def status(self, value: Status) -> None:
if value is Status.offline:
self._connection._status = 'invisible'
elif isinstance(value, Status):
@ -1077,7 +1077,7 @@ class Client:
*,
activity: Optional[BaseActivity] = None,
status: Optional[Status] = None,
):
) -> None:
"""|coro|
Changes the client's presence.

13
discord/colour.py

@ -32,7 +32,6 @@ from typing import (
Callable,
Optional,
Tuple,
Type,
Union,
)
@ -90,10 +89,10 @@ class Colour:
def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Colour) and self.value == other.value
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __str__(self) -> str:
@ -265,28 +264,28 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95A5A6)
lighter_gray: Callable[[Type[Self]], Self] = lighter_grey
lighter_gray = lighter_grey
@classmethod
def dark_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607D8B)
dark_gray: Callable[[Type[Self]], Self] = dark_grey
dark_gray = dark_grey
@classmethod
def light_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979C9F)
light_gray: Callable[[Type[Self]], Self] = light_grey
light_gray = light_grey
@classmethod
def darker_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546E7A)
darker_gray: Callable[[Type[Self]], Self] = darker_grey
darker_gray = darker_grey
@classmethod
def og_blurple(cls) -> Self:

10
discord/components.py

@ -310,9 +310,9 @@ class SelectOption:
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False,
) -> None:
self.label = label
self.value = label if value is MISSING else value
self.description = description
self.label: str = label
self.value: str = label if value is MISSING else value
self.description: Optional[str] = description
if emoji is not None:
if isinstance(emoji, str):
@ -322,8 +322,8 @@ class SelectOption:
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji = emoji
self.default = default
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
self.default: bool = default
def __repr__(self) -> str:
return (

10
discord/context_managers.py

@ -25,13 +25,15 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, TypeVar
if TYPE_CHECKING:
from .abc import Messageable
from types import TracebackType
BE = TypeVar('BE', bound=BaseException)
# fmt: off
__all__ = (
'Typing',
@ -67,13 +69,13 @@ class Typing:
async def __aenter__(self) -> None:
self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id)
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing())
self.task.add_done_callback(_typing_done_callback)
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
self.task.cancel()

12
discord/embeds.py

@ -189,10 +189,10 @@ class Embed:
):
self.colour = colour if colour is not EmptyEmbed else color
self.title = title
self.type = type
self.url = url
self.description = description
self.title: MaybeEmpty[str] = title
self.type: EmbedType = type
self.url: MaybeEmpty[str] = url
self.description: MaybeEmpty[str] = description
if self.title is not EmptyEmbed:
self.title = str(self.title)
@ -311,7 +311,7 @@ class Embed:
return getattr(self, '_colour', EmptyEmbed)
@colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]):
def colour(self, value: Union[int, Colour, _EmptyEmbed]) -> None:
if isinstance(value, (Colour, _EmptyEmbed)):
self._colour = value
elif isinstance(value, int):
@ -326,7 +326,7 @@ class Embed:
return getattr(self, '_timestamp', EmptyEmbed)
@timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
def timestamp(self, value: MaybeEmpty[datetime.datetime]) -> None:
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
value = value.astimezone()

4
discord/emoji.py

@ -142,10 +142,10 @@ class Emoji(_EmojiTag, AssetMixin):
def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:

44
discord/enums.py

@ -25,7 +25,7 @@ from __future__ import annotations
import types
from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping
__all__ = (
'Enum',
@ -131,38 +131,38 @@ class EnumMeta(type):
value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood
return actual_cls
def __iter__(cls):
def __iter__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
def __reversed__(cls):
def __reversed__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
def __len__(cls):
def __len__(cls) -> int:
return len(cls._enum_member_names_)
def __repr__(cls):
def __repr__(cls) -> str:
return f'<enum {cls.__name__}>'
@property
def __members__(cls):
def __members__(cls) -> Mapping[str, Any]:
return types.MappingProxyType(cls._enum_member_map_)
def __call__(cls, value):
def __call__(cls, value: str) -> Any:
try:
return cls._enum_value_map_[value]
except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}")
def __getitem__(cls, key):
def __getitem__(cls, key: str) -> Any:
return cls._enum_member_map_[key]
def __setattr__(cls, name, value):
def __setattr__(cls, name: str, value: Any) -> None:
raise TypeError('Enums are immutable.')
def __delattr__(cls, attr):
def __delattr__(cls, attr: str) -> None:
raise TypeError('Enums are immutable')
def __instancecheck__(self, instance):
def __instancecheck__(self, instance: Any) -> bool:
# isinstance(x, Y)
# -> __instancecheck__(Y, x)
try:
@ -197,7 +197,7 @@ class ChannelType(Enum):
private_thread = 12
stage_voice = 13
def __str__(self):
def __str__(self) -> str:
return self.name
@ -233,10 +233,10 @@ class SpeakingState(Enum):
soundshare = 2
priority = 4
def __str__(self):
def __str__(self) -> str:
return self.name
def __int__(self):
def __int__(self) -> int:
return self.value
@ -247,7 +247,7 @@ class VerificationLevel(Enum, comparable=True):
high = 3
highest = 4
def __str__(self):
def __str__(self) -> str:
return self.name
@ -256,7 +256,7 @@ class ContentFilter(Enum, comparable=True):
no_role = 1
all_members = 2
def __str__(self):
def __str__(self) -> str:
return self.name
@ -268,7 +268,7 @@ class Status(Enum):
do_not_disturb = 'dnd'
invisible = 'invisible'
def __str__(self):
def __str__(self) -> str:
return self.value
@ -280,7 +280,7 @@ class DefaultAvatar(Enum):
orange = 3
red = 4
def __str__(self):
def __str__(self) -> str:
return self.name
@ -467,7 +467,7 @@ class ActivityType(Enum):
custom = 4
competing = 5
def __int__(self):
def __int__(self) -> int:
return self.value
@ -542,7 +542,7 @@ class VideoQualityMode(Enum):
auto = 1
full = 2
def __int__(self):
def __int__(self) -> int:
return self.value
@ -552,7 +552,7 @@ class ComponentType(Enum):
select = 3
text_input = 4
def __int__(self):
def __int__(self) -> int:
return self.value
@ -571,7 +571,7 @@ class ButtonStyle(Enum):
red = 4
url = 5
def __int__(self):
def __int__(self) -> int:
return self.value

18
discord/ext/commands/_types.py

@ -23,21 +23,35 @@ DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple
T = TypeVar('T')
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from .bot import Bot, AutoShardedBot
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, 'Coro[T]'],
Callable[P, T],
]
else:
P = TypeVar('P')
MaybeCoroFunc = Tuple[P, T]
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
ContextT = TypeVar('ContextT', bound='Context')
_Bot = Union['Bot', 'AutoShardedBot']
BotT = TypeVar('BotT', bound=_Bot)
Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]

88
discord/ext/commands/bot.py

@ -33,7 +33,21 @@ import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
Iterable,
Collection,
overload,
)
import discord
from discord import app_commands
@ -55,10 +69,18 @@ if TYPE_CHECKING:
from discord.message import Message
from discord.abc import User, Snowflake
from ._types import (
_Bot,
BotT,
Check,
CoroFunc,
ContextT,
MaybeCoroFunc,
)
_Prefix = Union[Iterable[str], str]
_PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix]
PrefixType = Union[_Prefix, _PrefixCallable[BotT]]
__all__ = (
'when_mentioned',
'when_mentioned_or',
@ -68,11 +90,9 @@ __all__ = (
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
def when_mentioned(bot: _Bot, msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -81,7 +101,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -124,27 +144,33 @@ class _DefaultRepr:
return '<default-help-command>'
_default = _DefaultRepr()
_default: Any = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
class BotBase(GroupMixin[None]):
def __init__(
self,
command_prefix: PrefixType[BotT],
help_command: HelpCommand = _default,
description: Optional[str] = None,
**options: Any,
) -> None:
super().__init__(**options)
self.command_prefix = command_prefix
self.command_prefix: PrefixType[BotT] = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self._check_once: List[Check] = []
self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand] = None
self.description: str = inspect.cleandoc(description) if description else ''
self.owner_id: Optional[int] = options.get('owner_id')
self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set())
self.strip_after_prefix: bool = options.get('strip_after_prefix', False)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.')
@ -182,7 +208,7 @@ class BotBase(GroupMixin):
await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
async def on_command_error(self, context: Context[BotT], exception: errors.CommandError) -> None:
"""|coro|
The default command error handler provided by the bot.
@ -237,7 +263,7 @@ class BotBase(GroupMixin):
self.add_check(func) # type: ignore
return func
def add_check(self, func: Check, /, *, call_once: bool = False) -> None:
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -261,7 +287,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
def remove_check(self, func: Check, /, *, call_once: bool = False) -> None:
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@ -324,7 +350,7 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
async def can_run(self, ctx: Context[BotT], *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
@ -947,7 +973,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
await lib.setup(self) # type: ignore
await lib.setup(self)
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@ -1015,11 +1041,12 @@ class BotBase(GroupMixin):
"""
prefix = ret = self.command_prefix
if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message)
# self will be a Bot or AutoShardedBot
ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore
if not isinstance(ret, str):
try:
ret = list(ret)
ret = list(ret) # type: ignore
except TypeError:
# It's possible that a generator raised this exception. Don't
# replace it with our own error if that's the case.
@ -1048,15 +1075,15 @@ class BotBase(GroupMixin):
self,
message: Message,
*,
cls: Type[CXT] = ...,
) -> CXT: # type: ignore
cls: Type[ContextT] = ...,
) -> ContextT:
...
async def get_context(
self,
message: Message,
*,
cls: Type[CXT] = MISSING,
cls: Type[ContextT] = MISSING,
) -> Any:
r"""|coro|
@ -1137,7 +1164,7 @@ class BotBase(GroupMixin):
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT]) -> None:
"""|coro|
Invokes the command given under the invocation context and
@ -1189,9 +1216,10 @@ class BotBase(GroupMixin):
return
ctx = await self.get_context(message)
await self.invoke(ctx)
# the type of the invocation context's bot attribute will be correct
await self.invoke(ctx) # type: ignore
async def on_message(self, message):
async def on_message(self, message: Message) -> None:
await self.process_commands(message)

16
discord/ext/commands/cog.py

@ -30,7 +30,7 @@ from discord.utils import maybe_coroutine
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from ._types import _BaseCommand
from ._types import _BaseCommand, BotT
if TYPE_CHECKING:
from typing_extensions import Self
@ -112,7 +112,7 @@ class CogMeta(type):
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_commands__: List[Command[Any, ..., Any]]
__cog_is_app_commands_group__: bool
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
@ -406,7 +406,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
def bot_check_once(self, ctx: Context) -> bool:
def bot_check_once(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once`
check.
@ -416,7 +416,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def bot_check(self, ctx: Context) -> bool:
def bot_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check`
check.
@ -426,7 +426,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def cog_check(self, ctx: Context) -> bool:
def cog_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog.
@ -436,7 +436,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
@ -455,7 +455,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None:
async def cog_before_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
@ -470,7 +470,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None:
async def cog_after_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.

15
discord/ext/commands/context.py

@ -28,6 +28,8 @@ import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
from ._types import BotT
import discord.abc
import discord.utils
@ -59,7 +61,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
@ -133,10 +134,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None,
command: Optional[Command] = None,
command: Optional[Command[Any, ..., Any]] = None,
invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None,
invoked_subcommand: Optional[Command[Any, ..., Any]] = None,
subcommand_passed: Optional[str] = None,
command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None,
@ -146,11 +147,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.command: Optional[Command[Any, ..., Any]] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
@ -361,7 +362,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return None
cmd = cmd.copy()
cmd.context = self
cmd.context = self # type: ignore
if len(args) == 0:
await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping()
@ -390,7 +391,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
try:
if hasattr(entity, '__cog_commands__'):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
return await injected(entity) # type: ignore
elif isinstance(entity, Group):
injected = wrap_callback(cmd.send_group_help)
return await injected(entity)

86
discord/ext/commands/converter.py

@ -41,7 +41,6 @@ from typing import (
Tuple,
Union,
runtime_checkable,
overload,
)
import discord
@ -51,9 +50,8 @@ if TYPE_CHECKING:
from .context import Context
from discord.state import Channel
from discord.threads import Thread
from .bot import Bot, AutoShardedBot
_Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot])
from ._types import BotT, _Bot
__all__ = (
@ -87,7 +85,7 @@ __all__ = (
)
def _get_from_guilds(bot, getter, argument):
def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
result = None
for guild in bot.guilds:
result = getattr(guild, getter)(argument)
@ -115,7 +113,7 @@ class Converter(Protocol[T_co]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
async def convert(self, ctx: Context, argument: str) -> T_co:
async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
"""|coro|
The method to override to do conversion logic.
@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
2. Lookup by member, role, or channel mention.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None:
@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
"""
async def query_member_named(self, guild, argument):
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]):
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
async def query_member_by_id(self, bot, guild, user_id):
async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
ws = bot._get_websocket(shard_id=guild.shard_id)
cache = guild._state.member_cache_flags.joined
if ws.is_ratelimited():
@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
return None
return members[0]
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild
@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
and it's not available in cache.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None
state = ctx._state
@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[Union[Channel, Thread]]:
if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel
@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return ctx.bot.get_channel(channel_id)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]):
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id)
if message:
@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def check(c):
return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
else:
channel_id = int(match.group(1))
if guild:
@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
3. Lookup by name
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
.. versionadded: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument):
def parse_hex_number(self, argument: str) -> discord.Colour:
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
else:
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
def parse_rgb_number(self, argument: str, number: str) -> int:
if number[-1] == '%':
value = int(number[:-1])
if not (0 <= value <= 100):
@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(argument)
return value
def parse_rgb(self, argument, *, regex=RGB_REGEX):
def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
match = regex.match(argument)
if match is None:
raise BadColourArgument(argument)
@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
if argument[0] == '#':
return self.parse_hex_number(argument[1:])
@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
guild = ctx.guild
if not guild:
raise NoPrivateMessage()
@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
return discord.Game(name=argument)
@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
try:
invite = await ctx.bot.fetch_invite(argument)
return invite
@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
match = self._get_id_match(argument)
result = None
@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None
bot = ctx.bot
@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match:
@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument)
result = None
bot = ctx.bot
@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
guild = ctx.guild
match = self._get_id_match(argument)
result = None
@ -967,7 +965,7 @@ class clean_content(Converter[str]):
self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown
async def convert(self, ctx: Context[_Bot], argument: str) -> str:
async def convert(self, ctx: Context[BotT], argument: str) -> str:
msg = ctx.message
if ctx.guild:
@ -1047,10 +1045,10 @@ class Greedy(List[T]):
__slots__ = ('converter',)
def __init__(self, *, converter: T):
self.converter = converter
def __init__(self, *, converter: T) -> None:
self.converter: T = converter
def __repr__(self):
def __repr__(self) -> str:
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
_GenericAlias = type(List[T])
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
CONVERTER_MAPPING: Dict[Type[Any], Any] = {
CONVERTER_MAPPING: Dict[type, Any] = {
discord.Object: ObjectConverter,
discord.Member: MemberConverter,
discord.User: UserConverter,
@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
if converter is bool:
return _convert_to_bool(argument)
@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
"""|coro|
Runs converters for a given converter, argument, and parameter.

2
discord/ext/commands/cooldowns.py

@ -220,7 +220,7 @@ class CooldownMapping:
return self._type
@classmethod
def from_cooldown(cls, rate, per, type) -> Self:
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:

115
discord/ext/commands/core.py

@ -61,6 +61,8 @@ if TYPE_CHECKING:
from discord.message import Message
from ._types import (
BotT,
ContextT,
Coro,
CoroFunc,
Check,
@ -101,7 +103,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
@ -159,9 +160,9 @@ def get_signature_parameters(
return params
def wrap_callback(coro):
def wrap_callback(coro: Callable[P, Coro[T]]) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try:
ret = await coro(*args, **kwargs)
except CommandError:
@ -175,9 +176,11 @@ def wrap_callback(coro):
return wrapped
def hooked_wrapped_callback(command, ctx, coro):
def hooked_wrapped_callback(
command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]]
) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try:
ret = await coro(*args, **kwargs)
except CommandError:
@ -191,7 +194,7 @@ def hooked_wrapped_callback(command, ctx, coro):
raise CommandInvokeError(exc) from exc
finally:
if command._max_concurrency is not None:
await command._max_concurrency.release(ctx)
await command._max_concurrency.release(ctx.message)
await command.call_after_hooks(ctx)
return ret
@ -359,7 +362,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError:
checks = kwargs.get('checks', [])
self.checks: List[Check] = checks
self.checks: List[Check[ContextT]] = checks
try:
cooldown = func.__commands_cooldown__
@ -387,8 +390,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
self._before_invoke: Optional[Hook] = None
try:
@ -422,16 +425,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__
self.module: str = unwrap.__module__
try:
globalns = unwrap.__globals__
except AttributeError:
globalns = {}
self.params = get_signature_parameters(function, globalns)
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check, /) -> None:
def add_check(self, func: Check[ContextT], /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`.check`.
@ -450,7 +453,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func)
def remove_check(self, func: Check, /) -> None:
def remove_check(self, func: Check[ContextT], /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
@ -484,7 +487,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
self.cog = cog
async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T:
async def __call__(self, context: Context[BotT], *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro|
Calls the internal callback that the command holds.
@ -539,7 +542,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
return self.copy()
async def dispatch_error(self, ctx: Context, error: Exception) -> None:
async def dispatch_error(self, ctx: Context[BotT], error: CommandError) -> None:
ctx.command_failed = True
cog = self.cog
try:
@ -549,7 +552,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
injected = wrap_callback(coro)
if cog is not None:
await injected(cog, ctx, error)
await injected(cog, ctx, error) # type: ignore
else:
await injected(ctx, error)
@ -562,7 +565,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.bot.dispatch('command_error', ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
async def transform(self, ctx: Context[BotT], param: inspect.Parameter) -> Any:
required = param.default is param.empty
converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
@ -610,7 +613,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
async def _transform_greedy_pos(
self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view
result = []
while not view.eof:
@ -631,7 +636,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return param.default
return result
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any:
async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any:
view = ctx.view
previous = view.index
try:
@ -669,7 +674,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(reversed(entries))
@property
def parents(self) -> List[Group]:
def parents(self) -> List[Group[Any, ..., Any]]:
"""List[:class:`Group`]: Retrieves the parents of this command.
If the command has no parents then it returns an empty :class:`list`.
@ -687,7 +692,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return entries
@property
def root_parent(self) -> Optional[Group]:
def root_parent(self) -> Optional[Group[Any, ..., Any]]:
"""Optional[:class:`Group`]: Retrieves the root parent of this command.
If the command has no parents then it returns ``None``.
@ -716,7 +721,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def __str__(self) -> str:
return self.qualified_name
async def _parse_arguments(self, ctx: Context) -> None:
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
ctx.kwargs = {}
args = ctx.args
@ -752,7 +757,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None:
async def call_before_hooks(self, ctx: Context[BotT]) -> None:
# now that we're done preparing we can call the pre-command hooks
# first, call the command local hook:
cog = self.cog
@ -777,7 +782,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None:
await hook(ctx)
async def call_after_hooks(self, ctx: Context) -> None:
async def call_after_hooks(self, ctx: Context[BotT]) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog)
@ -796,7 +801,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None:
await hook(ctx)
def _prepare_cooldowns(self, ctx: Context) -> None:
def _prepare_cooldowns(self, ctx: Context[BotT]) -> None:
if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
@ -806,7 +811,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
async def prepare(self, ctx: Context) -> None:
async def prepare(self, ctx: Context[BotT]) -> None:
ctx.command = self
if not await self.can_run(ctx):
@ -830,7 +835,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
await self._max_concurrency.release(ctx) # type: ignore
raise
def is_on_cooldown(self, ctx: Context) -> bool:
def is_on_cooldown(self, ctx: Context[BotT]) -> bool:
"""Checks whether the command is currently on cooldown.
Parameters
@ -851,7 +856,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0
def reset_cooldown(self, ctx: Context) -> None:
def reset_cooldown(self, ctx: Context[BotT]) -> None:
"""Resets the cooldown on this command.
Parameters
@ -863,7 +868,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
bucket = self._buckets.get_bucket(ctx.message)
bucket.reset()
def get_cooldown_retry_after(self, ctx: Context) -> float:
def get_cooldown_retry_after(self, ctx: Context[BotT]) -> float:
"""Retrieves the amount of seconds before this command can be tried again.
.. versionadded:: 1.4
@ -887,7 +892,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return 0.0
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT]) -> None:
await self.prepare(ctx)
# terminate the invoked_subcommand chain.
@ -896,9 +901,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
await injected(*ctx.args, **ctx.kwargs) # type: ignore
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
ctx.command = self
await self._parse_arguments(ctx)
@ -936,7 +941,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
self.on_error: Error = coro
self.on_error: Error[Any] = coro
return coro
def has_error_handler(self) -> bool:
@ -1075,7 +1080,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(result)
async def can_run(self, ctx: Context) -> bool:
async def can_run(self, ctx: Context[BotT]) -> bool:
"""|coro|
Checks if the command can be executed by checking all the predicates
@ -1341,7 +1346,7 @@ class GroupMixin(Generic[CogT]):
def command(
self,
name: str = MISSING,
cls: Type[Command] = MISSING,
cls: Type[Command[Any, ..., Any]] = MISSING,
*args: Any,
**kwargs: Any,
) -> Any:
@ -1401,7 +1406,7 @@ class GroupMixin(Generic[CogT]):
def group(
self,
name: str = MISSING,
cls: Type[Group] = MISSING,
cls: Type[Group[Any, ..., Any]] = MISSING,
*args: Any,
**kwargs: Any,
) -> Any:
@ -1461,9 +1466,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
ret = super().copy()
for cmd in self.commands:
ret.add_command(cmd.copy())
return ret # type: ignore
return ret
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT]) -> None:
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
early_invoke = not self.invoke_without_command
@ -1481,7 +1486,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
await injected(*ctx.args, **ctx.kwargs) # type: ignore
ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
@ -1494,7 +1499,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous
await super().invoke(ctx)
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None:
ctx.invoked_subcommand = None
early_invoke = not self.invoke_without_command
if early_invoke:
@ -1592,7 +1597,7 @@ def command(
def command(
name: str = MISSING,
cls: Type[Command] = MISSING,
cls: Type[Command[Any, ..., Any]] = MISSING,
**attrs: Any,
) -> Any:
"""A decorator that transforms a function into a :class:`.Command`
@ -1662,7 +1667,7 @@ def group(
def group(
name: str = MISSING,
cls: Type[Group] = MISSING,
cls: Type[Group[Any, ..., Any]] = MISSING,
**attrs: Any,
) -> Any:
"""A decorator that transforms a function into a :class:`.Group`.
@ -1679,7 +1684,7 @@ def group(
return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]:
def check(predicate: Check[ContextT]) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1774,7 +1779,7 @@ def check(predicate: Check) -> Callable[[T], T]:
return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]:
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@ -1827,7 +1832,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
else:
unwrapped.append(pred)
async def predicate(ctx: Context) -> bool:
async def predicate(ctx: Context[BotT]) -> bool:
errors = []
for func in unwrapped:
try:
@ -1870,7 +1875,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
The name or ID of the role to check.
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
@ -1923,7 +1928,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
getter = functools.partial(discord.utils.get, ctx.author.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True
raise MissingAnyRole(list(items))
@ -2022,7 +2027,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
ch = ctx.channel
permissions = ch.permissions_for(ctx.author) # type: ignore
@ -2048,7 +2053,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
guild = ctx.guild
me = guild.me if guild is not None else ctx.bot.user
permissions = ctx.channel.permissions_for(me) # type: ignore
@ -2077,7 +2082,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild:
raise NoPrivateMessage
@ -2103,7 +2108,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild:
raise NoPrivateMessage
@ -2129,7 +2134,7 @@ def dm_only() -> Callable[[T], T]:
.. versionadded:: 1.1
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is not None:
raise PrivateMessageOnly()
return True
@ -2146,7 +2151,7 @@ def guild_only() -> Callable[[T], T]:
that is inherited from :exc:`.CheckFailure`.
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
return True
@ -2164,7 +2169,7 @@ def is_owner() -> Callable[[T], T]:
from :exc:`.CheckFailure`.
"""
async def predicate(ctx: Context) -> bool:
async def predicate(ctx: Context[BotT]) -> bool:
if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
return True
@ -2184,7 +2189,7 @@ def is_nsfw() -> Callable[[T], T]:
DM channels will also now pass this check.
"""
def pred(ctx: Context) -> bool:
def pred(ctx: Context[BotT]) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True

14
discord/ext/commands/errors.py

@ -39,6 +39,8 @@ if TYPE_CHECKING:
from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList
from ._types import BotT
__all__ = (
'CommandError',
@ -135,8 +137,8 @@ class ConversionError(CommandError):
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
def __init__(self, converter: Converter[Any], original: Exception) -> None:
self.converter: Converter[Any] = converter
self.original: Exception = original
@ -224,9 +226,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed.
"""
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
self.errors: List[Callable[[Context[BotT]], bool]] = errors
super().__init__('You do not have permission to run this command.')
@ -807,9 +809,9 @@ class BadUnionArgument(UserInputError):
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
self.converters: Tuple[type, ...] = converters
self.errors: List[CommandError] = errors
def _get_name(x):

18
discord/ext/commands/flags.py

@ -49,8 +49,6 @@ from typing import (
Tuple,
List,
Any,
Type,
TypeVar,
Union,
)
@ -70,6 +68,8 @@ if TYPE_CHECKING:
from .context import Context
from ._types import BotT
@dataclass
class Flag:
@ -148,7 +148,7 @@ def flag(
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]):
def validate_flag_name(name: str, forbidden: Set[str]) -> None:
if not name:
raise ValueError('flag names should not be empty')
@ -348,7 +348,7 @@ class FlagsMeta(type):
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters
return tuple(results)
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation
try:
@ -480,7 +480,7 @@ class FlagConverter(metaclass=FlagsMeta):
yield (flag.name, getattr(self, flag.attribute))
@classmethod
async def _construct_default(cls, ctx: Context) -> Self:
async def _construct_default(cls, ctx: Context[BotT]) -> Self:
self = cls.__new__(cls)
flags = cls.__commands_flags__
for flag in flags.values():
@ -546,7 +546,7 @@ class FlagConverter(metaclass=FlagsMeta):
return result
@classmethod
async def convert(cls, ctx: Context, argument: str) -> Self:
async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
"""|coro|
The method that actually converters an argument to the flag mapping.
@ -610,7 +610,7 @@ class FlagConverter(metaclass=FlagsMeta):
values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict:
values = dict(values) # type: ignore
values = dict(values)
setattr(self, flag.attribute, values)

293
discord/ext/commands/help.py

@ -22,13 +22,27 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import itertools
import copy
import functools
import inspect
import re
from typing import Optional, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Optional,
Generator,
List,
TypeVar,
Callable,
Any,
Dict,
Tuple,
Iterable,
Sequence,
Mapping,
)
import discord.utils
@ -36,7 +50,21 @@ from .core import Group, Command, get_signature_parameters
from .errors import CommandError
if TYPE_CHECKING:
from typing_extensions import Self
import inspect
import discord.abc
from .bot import BotBase
from .context import Context
from .cog import Cog
from ._types import (
Check,
ContextT,
BotT,
_Bot,
)
__all__ = (
'Paginator',
@ -45,7 +73,9 @@ __all__ = (
'MinimalHelpCommand',
)
MISSING = discord.utils.MISSING
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
# help -> shows info of bot on top/bottom and lists subcommands
# help command -> shows detailed info of command
@ -80,10 +110,10 @@ class Paginator:
Attributes
-----------
prefix: :class:`str`
The prefix inserted to every page. e.g. three backticks.
suffix: :class:`str`
The suffix appended at the end of every page. e.g. three backticks.
prefix: Optional[:class:`str`]
The prefix inserted to every page. e.g. three backticks, if any.
suffix: Optional[:class:`str`]
The suffix appended at the end of every page. e.g. three backticks, if any.
max_size: :class:`int`
The maximum amount of codepoints allowed in a page.
linesep: :class:`str`
@ -91,36 +121,38 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
self.linesep = linesep
def __init__(
self, prefix: Optional[str] = '```', suffix: Optional[str] = '```', max_size: int = 2000, linesep: str = '\n'
) -> None:
self.prefix: Optional[str] = prefix
self.suffix: Optional[str] = suffix
self.max_size: int = max_size
self.linesep: str = linesep
self.clear()
def clear(self):
def clear(self) -> None:
"""Clears the paginator to have no pages."""
if self.prefix is not None:
self._current_page = [self.prefix]
self._count = len(self.prefix) + self._linesep_len # prefix + newline
self._current_page: List[str] = [self.prefix]
self._count: int = len(self.prefix) + self._linesep_len # prefix + newline
else:
self._current_page = []
self._count = 0
self._pages = []
self._pages: List[str] = []
@property
def _prefix_len(self):
def _prefix_len(self) -> int:
return len(self.prefix) if self.prefix else 0
@property
def _suffix_len(self):
def _suffix_len(self) -> int:
return len(self.suffix) if self.suffix else 0
@property
def _linesep_len(self):
def _linesep_len(self) -> int:
return len(self.linesep)
def add_line(self, line='', *, empty=False):
def add_line(self, line: str = '', *, empty: bool = False) -> None:
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@ -152,7 +184,7 @@ class Paginator:
self._current_page.append('')
self._count += self._linesep_len
def close_page(self):
def close_page(self) -> None:
"""Prematurely terminate a page."""
if self.suffix is not None:
self._current_page.append(self.suffix)
@ -165,36 +197,38 @@ class Paginator:
self._current_page = []
self._count = 0
def __len__(self):
def __len__(self) -> int:
total = sum(len(p) for p in self._pages)
return total + self._count
@property
def pages(self):
def pages(self) -> List[str]:
"""List[:class:`str`]: Returns the rendered list of pages."""
# we have more than just the prefix in our current page
if len(self._current_page) > (0 if self.prefix is None else 1):
self.close_page()
return self._pages
def __repr__(self):
def __repr__(self) -> str:
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
return fmt.format(self)
def _not_overridden(f):
def _not_overridden(f: FuncT) -> FuncT:
f.__help_command_not_overridden__ = True
return f
class _HelpCommandImpl(Command):
def __init__(self, inject, *args, **kwargs):
def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
super().__init__(inject.command_callback, *args, **kwargs)
self._original = inject
self._injected = inject
self.params = get_signature_parameters(inject.command_callback, globals(), skip_parameters=1)
self._original: HelpCommand = inject
self._injected: HelpCommand = inject
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
inject.command_callback, globals(), skip_parameters=1
)
async def prepare(self, ctx):
async def prepare(self, ctx: Context[Any]) -> None:
self._injected = injected = self._original.copy()
injected.context = ctx
self.callback = injected.command_callback
@ -209,7 +243,7 @@ class _HelpCommandImpl(Command):
await super().prepare(ctx)
async def _parse_arguments(self, ctx):
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
# Make the parser think we don't have a cog so it doesn't
# inject the parameter into `ctx.args`.
original_cog = self.cog
@ -219,22 +253,26 @@ class _HelpCommandImpl(Command):
finally:
self.cog = original_cog
async def _on_error_cog_implementation(self, dummy, ctx, error):
async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
await self._injected.on_help_command_error(ctx, error)
def _inject_into_cog(self, cog):
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(*, _original=cog.get_commands):
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(*, _original=cog.walk_commands):
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
@ -244,7 +282,7 @@ class _HelpCommandImpl(Command):
cog.walk_commands = wrapped_walk_commands
self.cog = cog
def _eject_cog(self):
def _eject_cog(self) -> None:
if self.cog is None:
return
@ -298,7 +336,11 @@ class HelpCommand:
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
if TYPE_CHECKING:
__original_kwargs__: Dict[str, Any]
__original_args__: Tuple[Any, ...]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# To prevent race conditions of a single instance while also allowing
# for settings to be passed the original arguments passed must be assigned
# to allow for easier copies (which will be made when the help command is actually called)
@ -314,30 +356,31 @@ class HelpCommand:
self.__original_args__ = deepcopy(args)
return self
def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
def __init__(self, **options: Any) -> None:
self.show_hidden: bool = options.pop('show_hidden', False)
self.verify_checks: bool = options.pop('verify_checks', True)
self.command_attrs: Dict[str, Any]
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.context: Context = MISSING
self.context: Context[_Bot] = MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self):
def copy(self) -> Self:
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__)
obj._command_impl = self._command_impl
return obj
def _add_to_bot(self, bot):
def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs)
bot.add_command(command)
self._command_impl = command
def _remove_from_bot(self, bot):
def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog()
def add_check(self, func, /):
def add_check(self, func: Check[ContextT], /) -> None:
"""
Adds a check to the help command.
@ -355,7 +398,7 @@ class HelpCommand:
self._command_impl.add_check(func)
def remove_check(self, func, /):
def remove_check(self, func: Check[ContextT], /) -> None:
"""
Removes a check from the help command.
@ -376,15 +419,15 @@ class HelpCommand:
self._command_impl.remove_check(func)
def get_bot_mapping(self):
def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
"""Retrieves the bot mapping passed to :meth:`send_bot_help`."""
bot = self.context.bot
mapping = {cog: cog.get_commands() for cog in bot.cogs.values()}
mapping: Dict[Optional[Cog], List[Command[Any, ..., Any]]] = {cog: cog.get_commands() for cog in bot.cogs.values()}
mapping[None] = [c for c in bot.commands if c.cog is None]
return mapping
@property
def invoked_with(self):
def invoked_with(self) -> Optional[str]:
"""Similar to :attr:`Context.invoked_with` except properly handles
the case where :meth:`Context.send_help` is used.
@ -395,7 +438,7 @@ class HelpCommand:
Returns
---------
:class:`str`
Optional[:class:`str`]
The command name that triggered this invocation.
"""
command_name = self._command_impl.name
@ -404,7 +447,7 @@ class HelpCommand:
return command_name
return ctx.invoked_with
def get_command_signature(self, command):
def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
"""Retrieves the signature portion of the help page.
Parameters
@ -418,14 +461,14 @@ class HelpCommand:
The signature for the command.
"""
parent = command.parent
parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore - the parent will be a Group
entries = []
while parent is not None:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + ' ' + parent.signature)
parent = parent.parent
parent = parent.parent # type: ignore
parent_sig = ' '.join(reversed(entries))
if len(command.aliases) > 0:
@ -439,7 +482,7 @@ class HelpCommand:
return f'{self.context.clean_prefix}{alias} {command.signature}'
def remove_mentions(self, string):
def remove_mentions(self, string: str) -> str:
"""Removes mentions from the string to prevent abuse.
This includes ``@everyone``, ``@here``, member mentions and role mentions.
@ -450,13 +493,13 @@ class HelpCommand:
The string with mentions removed.
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
def replace(obj: re.Match, *, transforms: Dict[str, str] = self.MENTION_TRANSFORMS) -> str:
return transforms.get(obj.group(0), '@invalid')
return self.MENTION_PATTERN.sub(replace, string)
@property
def cog(self):
def cog(self) -> Optional[Cog]:
"""A property for retrieving or setting the cog for the help command.
When a cog is set for the help command, it is as-if the help command
@ -473,7 +516,7 @@ class HelpCommand:
return self._command_impl.cog
@cog.setter
def cog(self, cog):
def cog(self, cog: Optional[Cog]) -> None:
# Remove whatever cog is currently valid, if any
self._command_impl._eject_cog()
@ -481,7 +524,7 @@ class HelpCommand:
if cog is not None:
self._command_impl._inject_into_cog(cog)
def command_not_found(self, string):
def command_not_found(self, string: str) -> str:
"""|maybecoro|
A method called when a command is not found in the help command.
@ -502,7 +545,7 @@ class HelpCommand:
"""
return f'No command called "{string}" found.'
def subcommand_not_found(self, command, string):
def subcommand_not_found(self, command: Command[Any, ..., Any], string: str) -> str:
"""|maybecoro|
A method called when a command did not have a subcommand requested in the help command.
@ -532,7 +575,13 @@ class HelpCommand:
return f'Command "{command.qualified_name}" has no subcommand named {string}'
return f'Command "{command.qualified_name}" has no subcommands.'
async def filter_commands(self, commands, *, sort=False, key=None):
async def filter_commands(
self,
commands: Iterable[Command[Any, ..., Any]],
*,
sort: bool = False,
key: Optional[Callable[[Command[Any, ..., Any]], Any]] = None,
) -> List[Command[Any, ..., Any]]:
"""|coro|
Returns a filtered list of commands and optionally sorts them.
@ -546,7 +595,7 @@ class HelpCommand:
An iterable of commands that are getting filtered.
sort: :class:`bool`
Whether to sort the result.
key: Optional[Callable[:class:`Command`, Any]]
key: Optional[Callable[[:class:`Command`], Any]]
An optional key function to pass to :func:`py:sorted` that
takes a :class:`Command` as its sole parameter. If ``sort`` is
passed as ``True`` then this will default as the command name.
@ -565,14 +614,14 @@ class HelpCommand:
if self.verify_checks is False:
# if we do not need to verify the checks then we can just
# run it straight through normally without using await.
return sorted(iterator, key=key) if sort else list(iterator)
return sorted(iterator, key=key) if sort else list(iterator) # type: ignore - the key shouldn't be None
if self.verify_checks is None and not self.context.guild:
# if verify_checks is None and we're in a DM, don't verify
return sorted(iterator, key=key) if sort else list(iterator)
return sorted(iterator, key=key) if sort else list(iterator) # type: ignore
# if we're here then we need to check every command if it can run
async def predicate(cmd):
async def predicate(cmd: Command[Any, ..., Any]) -> bool:
try:
return await cmd.can_run(self.context)
except CommandError:
@ -588,7 +637,7 @@ class HelpCommand:
ret.sort(key=key)
return ret
def get_max_size(self, commands):
def get_max_size(self, commands: Sequence[Command[Any, ..., Any]]) -> int:
"""Returns the largest name length of the specified command list.
Parameters
@ -605,7 +654,7 @@ class HelpCommand:
as_lengths = (discord.utils._string_width(c.name) for c in commands)
return max(as_lengths, default=0)
def get_destination(self):
def get_destination(self) -> discord.abc.MessageableChannel:
"""Returns the :class:`~discord.abc.Messageable` where the help command will be output.
You can override this method to customise the behaviour.
@ -619,7 +668,7 @@ class HelpCommand:
"""
return self.context.channel
async def send_error_message(self, error):
async def send_error_message(self, error: str) -> None:
"""|coro|
Handles the implementation when an error happens in the help command.
@ -644,7 +693,7 @@ class HelpCommand:
await destination.send(error)
@_not_overridden
async def on_help_command_error(self, ctx, error):
async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None:
"""|coro|
The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
@ -664,7 +713,7 @@ class HelpCommand:
"""
pass
async def send_bot_help(self, mapping):
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
"""|coro|
Handles the implementation of the bot command page in the help command.
@ -693,7 +742,7 @@ class HelpCommand:
"""
return None
async def send_cog_help(self, cog):
async def send_cog_help(self, cog: Cog) -> None:
"""|coro|
Handles the implementation of the cog page in the help command.
@ -721,7 +770,7 @@ class HelpCommand:
"""
return None
async def send_group_help(self, group):
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
"""|coro|
Handles the implementation of the group page in the help command.
@ -749,7 +798,7 @@ class HelpCommand:
"""
return None
async def send_command_help(self, command):
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
"""|coro|
Handles the implementation of the single command page in the help command.
@ -787,7 +836,7 @@ class HelpCommand:
"""
return None
async def prepare_help_command(self, ctx, command=None):
async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None:
"""|coro|
A low level method that can be used to prepare the help command
@ -811,7 +860,7 @@ class HelpCommand:
"""
pass
async def command_callback(self, ctx, *, command=None):
async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None:
"""|coro|
The actual implementation of the help command.
@ -856,7 +905,7 @@ class HelpCommand:
for key in keys[1:]:
try:
found = cmd.all_commands.get(key)
found = cmd.all_commands.get(key) # type: ignore
except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
return await self.send_error_message(string)
@ -908,28 +957,28 @@ class DefaultHelpCommand(HelpCommand):
The paginator used to paginate the help command output.
"""
def __init__(self, **options):
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
def __init__(self, **options: Any) -> None:
self.width: int = options.pop('width', 80)
self.indent: int = options.pop('indent', 2)
self.sort_commands: bool = options.pop('sort_commands', True)
self.dm_help: bool = options.pop('dm_help', False)
self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
self.commands_heading: str = options.pop('commands_heading', "Commands:")
self.no_category: str = options.pop('no_category', 'No Category')
self.paginator: Paginator = options.pop('paginator', None)
if self.paginator is None:
self.paginator = Paginator()
self.paginator: Paginator = Paginator()
super().__init__(**options)
def shorten_text(self, text):
def shorten_text(self, text: str) -> str:
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[: self.width - 3].rstrip() + '...'
return text
def get_ending_note(self):
def get_ending_note(self) -> str:
""":class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes."""
command_name = self.invoked_with
return (
@ -937,7 +986,9 @@ class DefaultHelpCommand(HelpCommand):
f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category."
)
def add_indented_commands(self, commands, *, heading, max_size=None):
def add_indented_commands(
self, commands: Sequence[Command[Any, ..., Any]], *, heading: str, max_size: Optional[int] = None
) -> None:
"""Indents a list of commands after the specified heading.
The formatting is added to the :attr:`paginator`.
@ -973,13 +1024,13 @@ class DefaultHelpCommand(HelpCommand):
entry = f'{self.indent * " "}{name:<{width}} {command.short_doc}'
self.paginator.add_line(self.shorten_text(entry))
async def send_pages(self):
async def send_pages(self) -> None:
"""A helper utility to send the page output from :attr:`paginator` to the destination."""
destination = self.get_destination()
for page in self.paginator.pages:
await destination.send(page)
def add_command_formatting(self, command):
def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
"""A utility function to format the non-indented block of commands and groups.
Parameters
@ -1002,7 +1053,7 @@ class DefaultHelpCommand(HelpCommand):
self.paginator.add_line(line)
self.paginator.add_line()
def get_destination(self):
def get_destination(self) -> discord.abc.Messageable:
ctx = self.context
if self.dm_help is True:
return ctx.author
@ -1011,11 +1062,11 @@ class DefaultHelpCommand(HelpCommand):
else:
return ctx.channel
async def prepare_help_command(self, ctx, command):
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
self.paginator.clear()
await super().prepare_help_command(ctx, command)
async def send_bot_help(self, mapping):
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
ctx = self.context
bot = ctx.bot
@ -1045,12 +1096,12 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages()
async def send_command_help(self, command):
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
self.add_command_formatting(command)
self.paginator.close_page()
await self.send_pages()
async def send_group_help(self, group):
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
self.add_command_formatting(group)
filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
@ -1064,7 +1115,7 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages()
async def send_cog_help(self, cog):
async def send_cog_help(self, cog: Cog) -> None:
if cog.description:
self.paginator.add_line(cog.description, empty=True)
@ -1111,27 +1162,27 @@ class MinimalHelpCommand(HelpCommand):
The paginator used to paginate the help command output.
"""
def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
def __init__(self, **options: Any) -> None:
self.sort_commands: bool = options.pop('sort_commands', True)
self.commands_heading: str = options.pop('commands_heading', "Commands")
self.dm_help: bool = options.pop('dm_help', False)
self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000)
self.aliases_heading: str = options.pop('aliases_heading', "Aliases:")
self.no_category: str = options.pop('no_category', 'No Category')
self.paginator: Paginator = options.pop('paginator', None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
self.paginator: Paginator = Paginator(suffix=None, prefix=None)
super().__init__(**options)
async def send_pages(self):
async def send_pages(self) -> None:
"""A helper utility to send the page output from :attr:`paginator` to the destination."""
destination = self.get_destination()
for page in self.paginator.pages:
await destination.send(page)
def get_opening_note(self):
def get_opening_note(self) -> str:
"""Returns help command's opening note. This is mainly useful to override for i18n purposes.
The default implementation returns ::
@ -1150,10 +1201,10 @@ class MinimalHelpCommand(HelpCommand):
f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category."
)
def get_command_signature(self, command):
def get_command_signature(self, command: Command[Any, ..., Any]) -> str:
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
def get_ending_note(self):
def get_ending_note(self) -> str:
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
The default implementation does nothing.
@ -1163,9 +1214,9 @@ class MinimalHelpCommand(HelpCommand):
:class:`str`
The help command ending note.
"""
return None
return ''
def add_bot_commands_formatting(self, commands, heading):
def add_bot_commands_formatting(self, commands: Sequence[Command[Any, ..., Any]], heading: str) -> None:
"""Adds the minified bot heading with commands to the output.
The formatting should be added to the :attr:`paginator`.
@ -1186,7 +1237,7 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(f'__**{heading}**__')
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
def add_subcommand_formatting(self, command: Command[Any, ..., Any]) -> None:
"""Adds formatting information on a subcommand.
The formatting should be added to the :attr:`paginator`.
@ -1202,7 +1253,7 @@ class MinimalHelpCommand(HelpCommand):
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases):
def add_aliases_formatting(self, aliases: Sequence[str]) -> None:
"""Adds the formatting information on a command's aliases.
The formatting should be added to the :attr:`paginator`.
@ -1219,7 +1270,7 @@ class MinimalHelpCommand(HelpCommand):
"""
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
def add_command_formatting(self, command):
def add_command_formatting(self, command: Command[Any, ..., Any]) -> None:
"""A utility function to format commands and groups.
Parameters
@ -1246,7 +1297,7 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(line)
self.paginator.add_line()
def get_destination(self):
def get_destination(self) -> discord.abc.Messageable:
ctx = self.context
if self.dm_help is True:
return ctx.author
@ -1255,11 +1306,11 @@ class MinimalHelpCommand(HelpCommand):
else:
return ctx.channel
async def prepare_help_command(self, ctx, command):
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None:
self.paginator.clear()
await super().prepare_help_command(ctx, command)
async def send_bot_help(self, mapping):
async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None:
ctx = self.context
bot = ctx.bot
@ -1272,7 +1323,7 @@ class MinimalHelpCommand(HelpCommand):
no_category = f'\u200b{self.no_category}'
def get_category(command, *, no_category=no_category):
def get_category(command: Command[Any, ..., Any], *, no_category: str = no_category) -> str:
cog = command.cog
return cog.qualified_name if cog is not None else no_category
@ -1290,7 +1341,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages()
async def send_cog_help(self, cog):
async def send_cog_help(self, cog: Cog) -> None:
bot = self.context.bot
if bot.description:
self.paginator.add_line(bot.description, empty=True)
@ -1315,7 +1366,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages()
async def send_group_help(self, group):
async def send_group_help(self, group: Group[Any, ..., Any]) -> None:
self.add_command_formatting(group)
filtered = await self.filter_commands(group.commands, sort=self.sort_commands)
@ -1335,7 +1386,7 @@ class MinimalHelpCommand(HelpCommand):
await self.send_pages()
async def send_command_help(self, command):
async def send_command_help(self, command: Command[Any, ..., Any]) -> None:
self.add_command_formatting(command)
self.paginator.close_page()
await self.send_pages()

37
discord/ext/commands/view.py

@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
self.buffer = buffer
self.end = len(buffer)
def __init__(self, buffer: str) -> None:
self.index: int = 0
self.buffer: str = buffer
self.end: int = len(buffer)
self.previous = 0
@property
def current(self):
def current(self) -> Optional[str]:
return None if self.eof else self.buffer[self.index]
@property
def eof(self):
def eof(self) -> bool:
return self.index >= self.end
def undo(self):
def undo(self) -> None:
self.index = self.previous
def skip_ws(self):
def skip_ws(self) -> bool:
pos = 0
while not self.eof:
try:
@ -79,7 +84,7 @@ class StringView:
self.index += pos
return self.previous != self.index
def skip_string(self, string):
def skip_string(self, string: str) -> bool:
strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
@ -87,19 +92,19 @@ class StringView:
return True
return False
def read_rest(self):
def read_rest(self) -> str:
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
def read(self, n: int) -> str:
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
def get(self):
def get(self) -> Optional[str]:
try:
result = self.buffer[self.index + 1]
except IndexError:
@ -109,7 +114,7 @@ class StringView:
self.index += 1
return result
def get_word(self):
def get_word(self) -> str:
pos = 0
while not self.eof:
try:
@ -119,12 +124,12 @@ class StringView:
pos += 1
except IndexError:
break
self.previous = self.index
self.previous: int = self.index
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
def get_quoted_word(self):
def get_quoted_word(self) -> Optional[str]:
current = self.current
if current is None:
return None
@ -187,5 +192,5 @@ class StringView:
result.append(current)
def __repr__(self):
def __repr__(self) -> str:
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

8
discord/ext/tasks/__init__.py

@ -110,15 +110,15 @@ class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
self.future = future = loop.create_future()
self.loop: asyncio.AbstractEventLoop = loop
self.future: asyncio.Future[None] = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True)
self.handle = loop.call_later(relative_delta, self.future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]:
return self.future

2
discord/file.py

@ -74,7 +74,7 @@ class File:
def __init__(
self,
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase],
fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase],
filename: Optional[str] = None,
*,
spoiler: bool = False,

20
discord/flags.py

@ -46,8 +46,8 @@ BF = TypeVar('BF', bound='BaseFlags')
class flag_value:
def __init__(self, func: Callable[[Any], int]):
self.flag = func(None)
self.__doc__ = func.__doc__
self.flag: int = func(None)
self.__doc__: Optional[str] = func.__doc__
@overload
def __get__(self, instance: None, owner: Type[BF]) -> Self:
@ -65,7 +65,7 @@ class flag_value:
def __set__(self, instance: BaseFlags, value: bool) -> None:
instance._set_flag(self.flag, value)
def __repr__(self):
def __repr__(self) -> str:
return f'<flag_value flag={self.flag!r}>'
@ -73,8 +73,8 @@ class alias_flag_value(flag_value):
pass
def fill_with_flags(*, inverted: bool = False):
def decorator(cls: Type[BF]):
def fill_with_flags(*, inverted: bool = False) -> Callable[[Type[BF]], Type[BF]]:
def decorator(cls: Type[BF]) -> Type[BF]:
# fmt: off
cls.VALID_FLAGS = {
name: value.flag
@ -116,10 +116,10 @@ class BaseFlags:
self.value = value
return self
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.value == other.value
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -504,8 +504,8 @@ class Intents(BaseFlags):
__slots__ = ()
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE
def __init__(self, **kwargs: bool) -> None:
self.value: int = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
@ -1005,7 +1005,7 @@ class MemberCacheFlags(BaseFlags):
def __init__(self, **kwargs: bool):
bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1
self.value: int = (1 << bits) - 1
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')

30
discord/gateway.py

@ -54,6 +54,8 @@ __all__ = (
)
if TYPE_CHECKING:
from typing_extensions import Self
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
@ -62,10 +64,10 @@ if TYPE_CHECKING:
class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id
self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
self.shard_id: Optional[int] = shard_id
self.resume: bool = resume
self.op: str = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception):
@ -225,7 +227,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
ack_time = time.perf_counter()
self._last_ack = ack_time
self._last_recv = ack_time
self.latency = ack_time - self._last_send
self.latency: float = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
@ -339,7 +341,7 @@ class DiscordWebSocket:
@classmethod
async def from_client(
cls: Type[DWS],
cls,
client: Client,
*,
initial: bool = False,
@ -348,7 +350,7 @@ class DiscordWebSocket:
session: Optional[str] = None,
sequence: Optional[int] = None,
resume: bool = False,
) -> DWS:
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@ -821,11 +823,11 @@ class DiscordVoiceWebSocket:
*,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> None:
self.ws = socket
self.loop = loop
self._keep_alive = None
self._close_code = None
self.secret_key = None
self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code: Optional[int] = None
self.secret_key: Optional[str] = None
if hook:
self._hook = hook # type: ignore - type-checker doesn't like overriding methods
@ -864,7 +866,9 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
@classmethod
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS:
async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http

5
discord/guild.py

@ -123,6 +123,7 @@ if TYPE_CHECKING:
)
from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList
from .types.widget import EditWidgetSettings
VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
@ -3379,7 +3380,7 @@ class Guild(Hashable):
HTTPException
Editing the widget failed.
"""
payload = {}
payload: EditWidgetSettings = {}
if channel is not MISSING:
payload['channel_id'] = None if channel is None else channel.id
if enabled is not MISSING:
@ -3492,7 +3493,7 @@ class Guild(Hashable):
async def change_voice_state(
self, *, channel: Optional[abc.Snowflake], self_mute: bool = False, self_deaf: bool = False
):
) -> None:
"""|coro|
Changes client's voice state in the guild.

22
discord/http.py

@ -76,12 +76,9 @@ if TYPE_CHECKING:
audit_log,
channel,
command,
components,
emoji,
embed,
guild,
integration,
interactions,
invite,
member,
message,
@ -92,7 +89,6 @@ if TYPE_CHECKING:
channel,
widget,
threads,
voice,
scheduled_event,
sticker,
)
@ -122,7 +118,7 @@ class MultipartParameters(NamedTuple):
multipart: Optional[List[Dict[str, Any]]]
files: Optional[List[File]]
def __enter__(self):
def __enter__(self) -> Self:
return self
def __exit__(
@ -577,7 +573,7 @@ class HTTPClient:
return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload)
def leave_group(self, channel_id) -> Response[None]:
def leave_group(self, channel_id: Snowflake) -> Response[None]:
return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id))
# Message management
@ -1160,7 +1156,7 @@ class HTTPClient:
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]:
def edit_template(self, guild_id: Snowflake, code: str, payload: Dict[str, Any]) -> Response[template.Template]:
valid_keys = (
'name',
'description',
@ -1420,7 +1416,9 @@ class HTTPClient:
def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]:
return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id))
def edit_widget(self, guild_id: Snowflake, payload, reason: Optional[str] = None) -> Response[widget.WidgetSettings]:
def edit_widget(
self, guild_id: Snowflake, payload: widget.EditWidgetSettings, reason: Optional[str] = None
) -> Response[widget.WidgetSettings]:
return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason)
# Invite management
@ -1812,7 +1810,9 @@ class HTTPClient:
)
return self.request(r)
def upsert_global_command(self, application_id: Snowflake, payload) -> Response[command.ApplicationCommand]:
def upsert_global_command(
self, application_id: Snowflake, payload: command.ApplicationCommand
) -> Response[command.ApplicationCommand]:
r = Route('POST', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload)
@ -1845,7 +1845,9 @@ class HTTPClient:
)
return self.request(r)
def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[command.ApplicationCommand]]:
def bulk_upsert_global_commands(
self, application_id: Snowflake, payload: List[Dict[str, Any]]
) -> Response[List[command.ApplicationCommand]]:
r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id)
return self.request(r, json=payload)

17
discord/integrations.py

@ -39,6 +39,9 @@ __all__ = (
)
if TYPE_CHECKING:
from .guild import Guild
from .role import Role
from .state import ConnectionState
from .types.integration import (
IntegrationAccount as IntegrationAccountPayload,
Integration as IntegrationPayload,
@ -47,8 +50,6 @@ if TYPE_CHECKING:
IntegrationType,
IntegrationApplication as IntegrationApplicationPayload,
)
from .guild import Guild
from .role import Role
class IntegrationAccount:
@ -109,11 +110,11 @@ class Integration:
)
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
self.guild = guild
self._state = guild._state
self.guild: Guild = guild
self._state: ConnectionState = guild._state
self._from_data(data)
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None:
@ -123,7 +124,7 @@ class Integration:
self.account: IntegrationAccount = IntegrationAccount(data['account'])
user = data.get('user')
self.user = User(state=self._state, data=user) if user else None
self.user: Optional[User] = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled']
async def delete(self, *, reason: Optional[str] = None) -> None:
@ -319,7 +320,7 @@ class IntegrationApplication:
'user',
)
def __init__(self, *, data: IntegrationApplicationPayload, state):
def __init__(self, *, data: IntegrationApplicationPayload, state: ConnectionState) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
self.icon: Optional[str] = data['icon']
@ -358,7 +359,7 @@ class BotIntegration(Integration):
def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state)
self.application: IntegrationApplication = IntegrationApplication(data=data['application'], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:

7
discord/interactions.py

@ -54,6 +54,9 @@ if TYPE_CHECKING:
Interaction as InteractionPayload,
InteractionData,
)
from .types.webhook import (
Webhook as WebhookPayload,
)
from .client import Client
from .guild import Guild
from .state import ConnectionState
@ -229,7 +232,7 @@ class Interaction:
@utils.cached_slot_property('_cs_followup')
def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = {
payload: WebhookPayload = {
'id': self.application_id,
'type': 3,
'token': self.token,
@ -703,7 +706,7 @@ class InteractionResponse:
self._responded = True
async def send_modal(self, modal: Modal, /):
async def send_modal(self, modal: Modal, /) -> None:
"""|coro|
Responds to this interaction by sending a modal.

4
discord/invite.py

@ -456,7 +456,7 @@ class Invite(Hashable):
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data['channel_id'])
if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
channel = guild.get_channel(channel_id) or Object(id=channel_id)
else:
guild = Object(id=guild_id) if guild_id is not None else None
channel = Object(id=channel_id)
@ -539,7 +539,7 @@ class Invite(Hashable):
return self
async def delete(self, *, reason: Optional[str] = None):
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Revokes the instant invite.

13
discord/member.py

@ -27,9 +27,8 @@ from __future__ import annotations
import datetime
import inspect
import itertools
import sys
from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type
import discord.abc
@ -207,7 +206,7 @@ class _ClientStatus:
return self
def flatten_user(cls):
def flatten_user(cls: Any) -> Type[Member]:
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods
if attr.startswith('_'):
@ -333,7 +332,7 @@ class Member(discord.abc.Messageable, _UserTag):
default_avatar: Asset
avatar: Optional[Asset]
dm_channel: Optional[DMChannel]
create_dm = User.create_dm
create_dm: Callable[[], Coroutine[Any, Any, DMChannel]]
mutual_guilds: List[Guild]
public_flags: PublicUserFlags
banner: Optional[Asset]
@ -369,10 +368,10 @@ class Member(discord.abc.Messageable, _UserTag):
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -425,7 +424,7 @@ class Member(discord.abc.Messageable, _UserTag):
self._user = member._user
return self
async def _get_channel(self):
async def _get_channel(self) -> DMChannel:
ch = await self.create_dm()
return ch

8
discord/mentions.py

@ -92,10 +92,10 @@ class AllowedMentions:
roles: Union[bool, List[Snowflake]] = default,
replied_user: bool = default,
):
self.everyone = everyone
self.users = users
self.roles = roles
self.replied_user = replied_user
self.everyone: bool = everyone
self.users: Union[bool, List[Snowflake]] = users
self.roles: Union[bool, List[Snowflake]] = roles
self.replied_user: bool = replied_user
@classmethod
def all(cls) -> Self:

17
discord/message.py

@ -40,6 +40,7 @@ from typing import (
Tuple,
ClassVar,
Optional,
Type,
overload,
)
@ -71,7 +72,6 @@ if TYPE_CHECKING:
MessageReference as MessageReferencePayload,
MessageApplication as MessageApplicationPayload,
MessageActivity as MessageActivityPayload,
Reaction as ReactionPayload,
)
from .types.components import Component as ComponentPayload
@ -87,7 +87,7 @@ if TYPE_CHECKING:
from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
from .components import Component
from .state import ConnectionState
from .channel import TextChannel, GroupChannel, DMChannel
from .channel import TextChannel
from .mentions import AllowedMentions
from .user import User
from .role import Role
@ -95,6 +95,7 @@ if TYPE_CHECKING:
EmojiInputType = Union[Emoji, PartialEmoji, str]
__all__ = (
'Attachment',
'Message',
@ -104,7 +105,7 @@ __all__ = (
)
def convert_emoji_reaction(emoji):
def convert_emoji_reaction(emoji: Union[EmojiInputType, Reaction]) -> str:
if isinstance(emoji, Reaction):
emoji = emoji.emoji
@ -216,7 +217,7 @@ class Attachment(Hashable):
async def save(
self,
fp: Union[io.BufferedIOBase, PathLike],
fp: Union[io.BufferedIOBase, PathLike[Any]],
*,
seek_begin: bool = True,
use_cached: bool = False,
@ -510,7 +511,7 @@ class MessageReference:
to_message_reference_dict = to_dict
def flatten_handlers(cls):
def flatten_handlers(cls: Type[Message]) -> Type[Message]:
prefix = len('_handle_')
handlers = [
(key[prefix:], value)
@ -1036,7 +1037,7 @@ class Message(Hashable):
)
@utils.cached_slot_property('_cs_system_content')
def system_content(self):
def system_content(self) -> Optional[str]:
r""":class:`str`: A property that returns the content that is rendered
regardless of the :attr:`Message.type`.
@ -1657,7 +1658,7 @@ class Message(Hashable):
)
return Thread(guild=self.guild, state=self._state, data=data)
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
"""|coro|
A shortcut method to :meth:`.abc.Messageable.send` to reply to the
@ -1798,7 +1799,7 @@ class PartialMessage(Hashable):
# Also needed for duck typing purposes
# n.b. not exposed
pinned = property(None, lambda x, y: None)
pinned: Any = property(None, lambda x, y: None)
def __repr__(self) -> str:
return f'<PartialMessage id={self.id} channel={self.channel!r}>'

5
discord/opus.py

@ -363,7 +363,7 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm)
@ -373,8 +373,7 @@ class Encoder(_OpusStruct):
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore
return array.array('b', data[:ret]).tobytes()
class Decoder(_OpusStruct):

15
discord/partial_emoji.py

@ -42,6 +42,7 @@ if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
class _EmojiTag:
@ -99,13 +100,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
id: Optional[int]
def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None):
self.animated = animated
self.name = name
self.id = id
self.animated: bool = animated
self.name: str = name
self.id: Optional[int] = id
self._state: Optional[ConnectionState] = None
@classmethod
def from_dict(cls, data: Union[PartialEmojiPayload, Dict[str, Any]]) -> Self:
def from_dict(cls, data: Union[PartialEmojiPayload, ActivityEmoji, Dict[str, Any]]) -> Self:
return cls(
animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'),
@ -178,10 +179,10 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
def __repr__(self):
def __repr__(self) -> str:
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if self.is_unicode_emoji():
return isinstance(other, PartialEmoji) and self.name == other.name
@ -189,7 +190,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return self.id == other.id
return False
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:

4
discord/permissions.py

@ -276,7 +276,7 @@ class Permissions(BaseFlags):
# So 0000 OP2 0101 -> 0101
# The OP is base & ~denied.
# The OP2 is base | allowed.
self.value = (self.value & ~deny) | allow
self.value: int = (self.value & ~deny) | allow
@flag_value
def create_instant_invite(self) -> int:
@ -691,7 +691,7 @@ class PermissionOverwrite:
setattr(self, key, value)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, PermissionOverwrite) and self._values == other._values
def _set(self, key: str, value: Optional[bool]) -> None:

21
discord/player.py

@ -365,12 +365,11 @@ class FFmpegOpusAudio(FFmpegAudio):
bitrate: Optional[int] = None,
codec: Optional[str] = None,
executable: str = 'ffmpeg',
pipe=False,
stderr=None,
before_options=None,
options=None,
pipe: bool = False,
stderr: Optional[IO[bytes]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None,
) -> None:
args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
@ -635,7 +634,13 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
def __init__(self, source: AudioSource, client: VoiceClient, *, after=None):
def __init__(
self,
source: AudioSource,
client: VoiceClient,
*,
after: Optional[Callable[[Optional[Exception]], Any]] = None,
) -> None:
threading.Thread.__init__(self)
self.daemon: bool = True
self.source: AudioSource = source
@ -724,8 +729,8 @@ class AudioPlayer(threading.Thread):
self._speak(SpeakingState.none)
def resume(self, *, update_speaking: bool = True) -> None:
self.loops = 0
self._start = time.perf_counter()
self.loops: int = 0
self._start: float = time.perf_counter()
self._resumed.set()
if update_speaking:
self._speak(SpeakingState.voice)

4
discord/reaction.py

@ -94,10 +94,10 @@ class Reaction:
""":class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return other.emoji != self.emoji
return True

4
discord/role.py

@ -211,7 +211,7 @@ class Role(Hashable):
def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>'
def __lt__(self, other: Any) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented
@ -241,7 +241,7 @@ class Role(Hashable):
def __gt__(self, other: Any) -> bool:
return Role.__lt__(other, self)
def __ge__(self, other: Any) -> bool:
def __ge__(self, other: object) -> bool:
r = Role.__lt__(self, other)
if r is NotImplemented:
return NotImplemented

6
discord/scheduled_event.py

@ -132,7 +132,7 @@ class ScheduledEvent(Hashable):
self.guild_id: int = int(data['guild_id'])
self.name: str = data['name']
self.description: Optional[str] = data.get('description')
self.entity_type = try_enum(EntityType, data['entity_type'])
self.entity_type: EntityType = try_enum(EntityType, data['entity_type'])
self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id')
self.start_time: datetime = parse_time(data['scheduled_start_time'])
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status'])
@ -153,7 +153,7 @@ class ScheduledEvent(Hashable):
self.location: Optional[str] = data.get('location') if data else None
@classmethod
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload):
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload) -> None:
creator_id = data.get('creator_id')
self = cls(state=state, data=data)
if creator_id:
@ -180,7 +180,7 @@ class ScheduledEvent(Hashable):
return self.guild.get_channel(self.channel_id) # type: ignore
@property
def url(self):
def url(self) -> str:
""":class:`str`: The url for the scheduled event."""
return f'https://discord.com/events/{self.guild_id}/{self.id}'

5
discord/shard.py

@ -75,12 +75,12 @@ class EventItem:
self.shard: Optional['Shard'] = shard
self.error: Optional[Exception] = error
def __lt__(self, other: Any) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, EventItem):
return NotImplemented
return self.type < other.type
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, EventItem):
return NotImplemented
return self.type == other.type
@ -409,6 +409,7 @@ class AutoShardedClient(Client):
async def launch_shards(self) -> None:
if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway = await self.http.get_bot_gateway()
else:
gateway = await self.http.get_gateway()

6
discord/stage_instance.py

@ -97,11 +97,11 @@ class StageInstance(Hashable):
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
self._state = state
self.guild = guild
self._state: ConnectionState = state
self.guild: Guild = guild
self._update(data)
def _update(self, data: StageInstancePayload):
def _update(self, data: StageInstancePayload) -> None:
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']

58
discord/state.py

@ -43,6 +43,8 @@ from typing import (
Sequence,
Tuple,
Deque,
Literal,
overload,
)
import weakref
import inspect
@ -88,7 +90,7 @@ if TYPE_CHECKING:
from .types.activity import Activity as ActivityPayload
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
@ -165,9 +167,9 @@ class ConnectionState:
def __init__(
self,
*,
dispatch: Callable,
handlers: Dict[str, Callable],
hooks: Dict[str, Callable],
dispatch: Callable[..., Any],
handlers: Dict[str, Callable[..., Any]],
hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]],
http: HTTPClient,
**options: Any,
) -> None:
@ -178,9 +180,9 @@ class ConnectionState:
if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000
self.dispatch: Callable = dispatch
self.handlers: Dict[str, Callable] = handlers
self.hooks: Dict[str, Callable] = hooks
self.dispatch: Callable[..., Any] = dispatch
self.handlers: Dict[str, Callable[..., Any]] = handlers
self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks
self.shard_count: Optional[int] = None
self._ready_task: Optional[asyncio.Task] = None
self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id')
@ -245,6 +247,7 @@ class ConnectionState:
if not intents.members or cache_flags._empty:
self.store_user = self.store_user_no_intents # type: ignore - This reassignment is on purpose
self.parsers: Dict[str, Callable[[Any], None]]
self.parsers = parsers = {}
for attr, func in inspect.getmembers(self):
if attr.startswith('parse_'):
@ -343,13 +346,13 @@ class ConnectionState:
self._users[user_id] = user
return user
def store_user_no_intents(self, data):
def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data)
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data)
def get_user(self, id):
def get_user(self, id: int) -> Optional[User]:
return self._users.get(id)
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
@ -571,8 +574,7 @@ class ConnectionState:
pass
else:
self.application_id = utils._get_as_snowflake(application, 'id')
# flags will always be present here
self.application_flags = ApplicationFlags._from_value(application['flags'])
self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore
@ -743,7 +745,7 @@ class ConnectionState:
self.dispatch('presence_update', old_member, member)
def parse_user_update(self, data: gw.UserUpdateEvent):
def parse_user_update(self, data: gw.UserUpdateEvent) -> None:
if self.user:
self.user._update(data)
@ -1050,7 +1052,7 @@ class ConnectionState:
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data):
def _get_create_guild(self, data: gw.GuildCreateEvent) -> Guild:
if data.get('unavailable') is False:
# GUILD_CREATE with unavailable in the response
# usually means that the guild has become available
@ -1063,10 +1065,22 @@ class ConnectionState:
return self._add_guild_from_data(data)
def is_guild_evicted(self, guild) -> bool:
def is_guild_evicted(self, guild: Guild) -> bool:
return guild.id not in self._guilds
async def chunk_guild(self, guild, *, wait=True, cache=None):
@overload
async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., cache: Optional[bool] = ...) -> List[Member]:
...
@overload
async def chunk_guild(
self, guild: Guild, *, wait: Literal[False] = ..., cache: Optional[bool] = ...
) -> asyncio.Future[List[Member]]:
...
async def chunk_guild(
self, guild: Guild, *, wait: bool = True, cache: Optional[bool] = None
) -> Union[List[Member], asyncio.Future[List[Member]]]:
cache = cache or self.member_cache_flags.joined
request = self._chunk_requests.get(guild.id)
if request is None:
@ -1445,16 +1459,19 @@ class ConnectionState:
return channel.guild.get_member(user_id)
return self.get_user(user_id)
def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]:
def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]:
emoji_id = utils._get_as_snowflake(data, 'id')
if not emoji_id:
return data['name']
# the name key will be a str
return data['name'] # type: ignore
try:
return self._emojis[emoji_id]
except KeyError:
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name'])
return PartialEmoji.with_state(
self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore
)
def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
emoji_id = emoji.id
@ -1589,6 +1606,7 @@ class AutoShardedConnectionState(ConnectionState):
if not hasattr(self, '_ready_state'):
self._ready_state = asyncio.Queue()
self.user: Optional[ClientUser]
self.user = user = ClientUser(state=self, data=data['user'])
# self._users is a list of Users, we're setting a ClientUser
self._users[user.id] = user # type: ignore
@ -1599,8 +1617,8 @@ class AutoShardedConnectionState(ConnectionState):
except KeyError:
pass
else:
self.application_id = utils._get_as_snowflake(application, 'id')
self.application_flags = ApplicationFlags._from_value(application['flags'])
self.application_id: Optional[int] = utils._get_as_snowflake(application, 'id')
self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload

2
discord/sticker.py

@ -228,7 +228,7 @@ class StickerItem(_StickerTag):
The retrieved sticker.
"""
data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore
cls, _ = _sticker_factory(data['type'])
return cls(state=self._state, data=data)

71
discord/threads.py

@ -41,6 +41,8 @@ __all__ = (
)
if TYPE_CHECKING:
from typing_extensions import Self
from .types.threads import (
Thread as ThreadPayload,
ThreadMember as ThreadMemberPayload,
@ -147,13 +149,13 @@ class Thread(Messageable, Hashable):
'_created_at',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None:
self._state: ConnectionState = state
self.guild = guild
self.guild: Guild = guild
self._members: Dict[int, ThreadMember] = {}
self._from_data(data)
async def _get_channel(self):
async def _get_channel(self) -> Self:
return self
def __repr__(self) -> str:
@ -166,17 +168,18 @@ class Thread(Messageable, Hashable):
return self.name
def _from_data(self, data: ThreadPayload):
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.owner_id = int(data['owner_id'])
self.name = data['name']
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self.id: int = int(data['id'])
self.parent_id: int = int(data['parent_id'])
self.owner_id: int = int(data['owner_id'])
self.name: str = data['name']
self._type: ChannelType = try_enum(ChannelType, data['type'])
self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.message_count: int = data['message_count']
self.member_count: int = data['member_count']
self._unroll_metadata(data['thread_metadata'])
self.me: Optional[ThreadMember]
try:
member = data['member']
except KeyError:
@ -185,15 +188,15 @@ class Thread(Messageable, Hashable):
self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
self._created_at = parse_time(data.get('create_timestamp'))
def _update(self, data):
self.archived: bool = data['archived']
self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration: int = data['auto_archive_duration']
self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
self.locked: bool = data.get('locked', False)
self.invitable: bool = data.get('invitable', True)
self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
def _update(self, data: ThreadPayload) -> None:
try:
self.name = data['name']
except KeyError:
@ -602,7 +605,7 @@ class Thread(Messageable, Hashable):
# The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
async def join(self):
async def join(self) -> None:
"""|coro|
Joins this thread.
@ -619,7 +622,7 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.join_thread(self.id)
async def leave(self):
async def leave(self) -> None:
"""|coro|
Leaves this thread.
@ -631,7 +634,7 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.leave_thread(self.id)
async def add_user(self, user: Snowflake, /):
async def add_user(self, user: Snowflake, /) -> None:
"""|coro|
Adds a user to this thread.
@ -654,7 +657,7 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.add_user_to_thread(self.id, user.id)
async def remove_user(self, user: Snowflake, /):
async def remove_user(self, user: Snowflake, /) -> None:
"""|coro|
Removes a user from this thread.
@ -718,7 +721,7 @@ class Thread(Messageable, Hashable):
members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members]
async def delete(self):
async def delete(self) -> None:
"""|coro|
Deletes this thread.
@ -806,28 +809,28 @@ class ThreadMember(Hashable):
'parent',
)
def __init__(self, parent: Thread, data: ThreadMemberPayload):
self.parent = parent
self._state = parent._state
def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None:
self.parent: Thread = parent
self._state: ConnectionState = parent._state
self._from_data(data)
def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload):
def _from_data(self, data: ThreadMemberPayload) -> None:
try:
self.id = int(data['user_id'])
except KeyError:
assert self._state.self_id is not None
self.id = self._state.self_id
self.id = self._state.self_id # type: ignore
self.thread_id: int
try:
self.thread_id = int(data['id'])
except KeyError:
self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp'])
self.flags = data['flags']
self.joined_at: datetime = parse_time(data['join_timestamp'])
self.flags: int = data['flags']
@property
def thread(self) -> Thread:

1
discord/types/activity.py

@ -112,3 +112,4 @@ class Activity(_BaseActivity, total=False):
session_id: Optional[str]
instance: bool
buttons: List[ActivityButton]
sync_id: str

5
discord/types/widget.py

@ -58,3 +58,8 @@ class Widget(TypedDict):
class WidgetSettings(TypedDict):
enabled: bool
channel_id: Optional[Snowflake]
class EditWidgetSettings(TypedDict, total=False):
enabled: bool
channel_id: Optional[Snowflake]

17
discord/ui/button.py

@ -44,6 +44,7 @@ if TYPE_CHECKING:
from .view import View
from ..emoji import Emoji
from ..types.components import ButtonComponent as ButtonComponentPayload
V = TypeVar('V', bound='View', covariant=True)
@ -124,7 +125,7 @@ class Button(Item[V]):
style=style,
emoji=emoji,
)
self.row = row
self.row: Optional[int] = row
@property
def style(self) -> ButtonStyle:
@ -132,7 +133,7 @@ class Button(Item[V]):
return self._underlying.style
@style.setter
def style(self, value: ButtonStyle):
def style(self, value: ButtonStyle) -> None:
self._underlying.style = value
@property
@ -144,7 +145,7 @@ class Button(Item[V]):
return self._underlying.custom_id
@custom_id.setter
def custom_id(self, value: Optional[str]):
def custom_id(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str')
@ -156,7 +157,7 @@ class Button(Item[V]):
return self._underlying.url
@url.setter
def url(self, value: Optional[str]):
def url(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str')
self._underlying.url = value
@ -167,7 +168,7 @@ class Button(Item[V]):
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
def disabled(self, value: bool) -> None:
self._underlying.disabled = bool(value)
@property
@ -176,7 +177,7 @@ class Button(Item[V]):
return self._underlying.label
@label.setter
def label(self, value: Optional[str]):
def label(self, value: Optional[str]) -> None:
self._underlying.label = str(value) if value is not None else value
@property
@ -185,7 +186,7 @@ class Button(Item[V]):
return self._underlying.emoji
@emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]) -> None:
if value is not None:
if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
@ -212,7 +213,7 @@ class Button(Item[V]):
def type(self) -> ComponentType:
return self._underlying.type
def to_component_dict(self):
def to_component_dict(self) -> ButtonComponentPayload:
return self._underlying.to_dict()
def is_dispatchable(self) -> bool:

4
discord/ui/item.py

@ -101,7 +101,7 @@ class Item(Generic[V]):
return self._row
@row.setter
def row(self, value: Optional[int]):
def row(self, value: Optional[int]) -> None:
if value is None:
self._row = None
elif 5 > value >= 0:
@ -118,7 +118,7 @@ class Item(Generic[V]):
"""Optional[:class:`View`]: The underlying view for this item."""
return self._view
async def callback(self, interaction: Interaction):
async def callback(self, interaction: Interaction) -> Any:
"""|coro|
The callback associated with this UI item.

8
discord/ui/modal.py

@ -38,6 +38,8 @@ from .item import Item
from .view import View
if TYPE_CHECKING:
from typing_extensions import Self
from ..interactions import Interaction
from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload
@ -101,7 +103,7 @@ class Modal(View):
title: str
__discord_ui_modal__ = True
__modal_children_items__: ClassVar[Dict[str, Item]] = {}
__modal_children_items__: ClassVar[Dict[str, Item[Self]]] = {}
def __init_subclass__(cls, *, title: str = MISSING) -> None:
if title is not MISSING:
@ -139,7 +141,7 @@ class Modal(View):
super().__init__(timeout=timeout)
async def on_submit(self, interaction: Interaction):
async def on_submit(self, interaction: Interaction) -> None:
"""|coro|
Called when the modal is submitted.
@ -169,7 +171,7 @@ class Modal(View):
print(f'Ignoring exception in modal {self}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]):
def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]) -> None:
for component in components:
if component['type'] == 1:
self.refresh(component['components'])

18
discord/ui/select.py

@ -121,7 +121,7 @@ class Select(Item[V]):
options=options,
disabled=disabled,
)
self.row = row
self.row: Optional[int] = row
@property
def custom_id(self) -> str:
@ -129,7 +129,7 @@ class Select(Item[V]):
return self._underlying.custom_id
@custom_id.setter
def custom_id(self, value: str):
def custom_id(self, value: str) -> None:
if not isinstance(value, str):
raise TypeError('custom_id must be None or str')
@ -141,7 +141,7 @@ class Select(Item[V]):
return self._underlying.placeholder
@placeholder.setter
def placeholder(self, value: Optional[str]):
def placeholder(self, value: Optional[str]) -> None:
if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str')
@ -153,7 +153,7 @@ class Select(Item[V]):
return self._underlying.min_values
@min_values.setter
def min_values(self, value: int):
def min_values(self, value: int) -> None:
self._underlying.min_values = int(value)
@property
@ -162,7 +162,7 @@ class Select(Item[V]):
return self._underlying.max_values
@max_values.setter
def max_values(self, value: int):
def max_values(self, value: int) -> None:
self._underlying.max_values = int(value)
@property
@ -171,7 +171,7 @@ class Select(Item[V]):
return self._underlying.options
@options.setter
def options(self, value: List[SelectOption]):
def options(self, value: List[SelectOption]) -> None:
if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption')
if not all(isinstance(obj, SelectOption) for obj in value):
@ -187,7 +187,7 @@ class Select(Item[V]):
description: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False,
):
) -> None:
"""Adds an option to the select menu.
To append a pre-existing :class:`discord.SelectOption` use the
@ -226,7 +226,7 @@ class Select(Item[V]):
self.append_option(option)
def append_option(self, option: SelectOption):
def append_option(self, option: SelectOption) -> None:
"""Appends an option to the select menu.
Parameters
@ -251,7 +251,7 @@ class Select(Item[V]):
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
def disabled(self, value: bool) -> None:
self._underlying.disabled = bool(value)
@property

30
discord/ui/view.py

@ -50,6 +50,8 @@ __all__ = (
if TYPE_CHECKING:
from typing_extensions import Self
from ..interactions import Interaction
from ..message import Message
from ..types.components import Component as ComponentPayload
@ -163,7 +165,7 @@ class View:
cls.__view_children_items__ = children
def _init_children(self) -> List[Item]:
def _init_children(self) -> List[Item[Self]]:
children = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
@ -175,7 +177,7 @@ class View:
def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.children: List[Item] = self._init_children()
self.children: List[Item[Self]] = self._init_children()
self.__weights = _ViewWeights(self.children)
self.id: str = os.urandom(16).hex()
self.__cancel_callback: Optional[Callable[[View], None]] = None
@ -250,7 +252,7 @@ class View:
view.add_item(_component_to_item(component))
return view
def add_item(self, item: Item) -> None:
def add_item(self, item: Item[Any]) -> None:
"""Adds an item to the view.
Parameters
@ -278,7 +280,7 @@ class View:
item._view = self
self.children.append(item)
def remove_item(self, item: Item) -> None:
def remove_item(self, item: Item[Any]) -> None:
"""Removes an item from the view.
Parameters
@ -334,7 +336,7 @@ class View:
"""
pass
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
async def on_error(self, error: Exception, item: Item[Any], interaction: Interaction) -> None:
"""|coro|
A callback that is called when an item's callback or :meth:`interaction_check`
@ -395,16 +397,16 @@ class View:
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
def refresh(self, components: List[Component]):
def refresh(self, components: List[Component]) -> None:
# This is pretty hacky at the moment
# fmt: off
old_state: Dict[Tuple[int, str], Item] = {
old_state: Dict[Tuple[int, str], Item[Any]] = {
(item.type.value, item.custom_id): item # type: ignore
for item in self.children
if item.is_dispatchable()
}
# fmt: on
children: List[Item] = []
children: List[Item[Any]] = []
for component in _walk_all_components(components):
try:
older = old_state[(component.type.value, component.custom_id)] # type: ignore
@ -494,7 +496,7 @@ class ViewStore:
for k in to_remove:
del self._views[k]
def add_view(self, view: View, message_id: Optional[int] = None):
def add_view(self, view: View, message_id: Optional[int] = None) -> None:
view._start_listening_from_store(self)
if view.__discord_ui_modal__:
self._modals[view.custom_id] = view # type: ignore
@ -509,7 +511,7 @@ class ViewStore:
if message_id is not None:
self._synced_message_views[message_id] = view
def remove_view(self, view: View):
def remove_view(self, view: View) -> None:
if view.__discord_ui_modal__:
self._modals.pop(view.custom_id, None) # type: ignore
return
@ -523,7 +525,7 @@ class ViewStore:
del self._synced_message_views[key]
break
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction):
def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None:
self.__verify_integrity()
message_id: Optional[int] = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id)
@ -542,7 +544,7 @@ class ViewStore:
custom_id: str,
interaction: Interaction,
components: List[ModalSubmitComponentInteractionDataPayload],
):
) -> None:
modal = self._modals.get(custom_id)
if modal is None:
_log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id)
@ -551,13 +553,13 @@ class ViewStore:
modal.refresh(components)
modal._dispatch_submit(interaction)
def is_message_tracked(self, message_id: int):
def is_message_tracked(self, message_id: int) -> bool:
return message_id in self._synced_message_views
def remove_message_tracking(self, message_id: int) -> Optional[View]:
return self._synced_message_views.pop(message_id, None)
def update_from_message(self, message_id: int, components: List[ComponentPayload]):
def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None:
# pre-req: is_message_tracked == true
view = self._synced_message_views[message_id]
view.refresh([_component_factory(d) for d in components])

6
discord/user.py

@ -99,10 +99,10 @@ class BaseUser(_UserTag):
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -444,7 +444,7 @@ class User(BaseUser, discord.abc.Messageable):
def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
async def _get_channel(self):
async def _get_channel(self) -> DMChannel:
ch = await self.create_dm()
return ch

36
discord/utils.py

@ -29,6 +29,7 @@ from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
@ -42,6 +43,7 @@ from typing import (
NamedTuple,
Optional,
Protocol,
Set,
Sequence,
Tuple,
Type,
@ -66,7 +68,7 @@ import warnings
import yarl
try:
import orjson
import orjson # type: ignore
except ModuleNotFoundError:
HAS_ORJSON = False
else:
@ -123,7 +125,7 @@ class _cached_property:
if TYPE_CHECKING:
from functools import cached_property as cached_property
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Self
from .permissions import Permissions
from .abc import Snowflake
@ -135,8 +137,16 @@ if TYPE_CHECKING:
P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, Coroutine[Any, Any, 'T']],
Callable[P, 'T'],
]
_SnowflakeListBase = array.array[int]
else:
cached_property = _cached_property
_SnowflakeListBase = array.array
T = TypeVar('T')
@ -178,7 +188,7 @@ class classproperty(Generic[T_co]):
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance, value) -> None:
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
@ -210,7 +220,7 @@ class SequenceProxy(Sequence[T_co]):
def __reversed__(self) -> Iterator[T_co]:
return reversed(self.__proxied)
def index(self, value: Any, *args, **kwargs) -> int:
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
return self.__proxied.index(value, *args, **kwargs)
def count(self, value: Any) -> int:
@ -578,7 +588,7 @@ def _is_submodule(parent: str, child: str) -> bool:
if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore
def _to_json(obj: Any) -> str:
return orjson.dumps(obj).decode('utf-8')
_from_json = orjson.loads # type: ignore
@ -602,15 +612,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
return float(reset_after)
async def maybe_coroutine(f, *args, **kwargs):
async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
value = f(*args, **kwargs)
if _isawaitable(value):
return await value
else:
return value
return value # type: ignore
async def async_all(gen, *, check=_isawaitable):
async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool:
for elem in gen:
if check(elem):
elem = await elem
@ -619,7 +629,7 @@ async def async_all(gen, *, check=_isawaitable):
return True
async def sane_wait_for(futures, *, timeout):
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]:
ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
@ -637,7 +647,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]:
continue
def compute_timedelta(dt: datetime.datetime):
def compute_timedelta(dt: datetime.datetime) -> float:
if dt.tzinfo is None:
dt = dt.astimezone()
now = datetime.datetime.now(datetime.timezone.utc)
@ -686,7 +696,7 @@ def valid_icon_size(size: int) -> bool:
return not size & (size - 1) and 4096 >= size >= 16
class SnowflakeList(array.array):
class SnowflakeList(_SnowflakeListBase):
"""Internal data storage class to efficiently store a list of snowflakes.
This should have the following characteristics:
@ -705,7 +715,7 @@ class SnowflakeList(array.array):
def __init__(self, data: Iterable[int], *, is_sorted: bool = False):
...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self:
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None:
@ -1010,7 +1020,7 @@ def evaluate_annotation(
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
) -> Any:
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals

10
discord/voice_client.py

@ -262,7 +262,7 @@ class VoiceClient(VoiceProtocol):
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
warn_nacl = not has_nacl
warn_nacl: bool = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
@ -279,7 +279,7 @@ class VoiceClient(VoiceProtocol):
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user # type: ignore - user can't be None after login
def checked_add(self, attr, value, limit):
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
setattr(self, attr, 0)
@ -289,7 +289,7 @@ class VoiceClient(VoiceProtocol):
# connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id']
self.session_id: str = data['session_id']
channel_id = data['channel_id']
if not self._handshaking or self._potentially_reconnecting:
@ -323,12 +323,12 @@ class VoiceClient(VoiceProtocol):
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
# Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:]
self.endpoint: str = self.endpoint[6:]
# This gets set later
self.endpoint_ip = MISSING
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
if not self._handshaking:

122
discord/webhook/async_.py

@ -30,7 +30,7 @@ import json
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
from contextvars import ContextVar
import weakref
@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType
from ..user import BaseUser, User
from ..flags import MessageFlags
from ..asset import Asset
from ..http import Route, handle_message_parameters, MultipartParameters
from ..http import Route, handle_message_parameters, MultipartParameters, HTTPClient
from ..mixins import Hashable
from ..channel import PartialMessageable
from ..file import File
@ -58,24 +58,38 @@ __all__ = (
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..message import Attachment
from ..state import ConnectionState
from ..http import Response
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
from ..ui.view import View
import datetime
from ..types.webhook import (
Webhook as WebhookPayload,
SourceGuild as SourceGuildPayload,
)
from ..types.message import (
Message as MessagePayload,
)
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
from ..ui.view import View
import datetime
from ..types.user import (
User as UserPayload,
PartialUser as PartialUserPayload,
)
from ..types.channel import (
PartialChannel as PartialChannelPayload,
)
BE = TypeVar('BE', bound=BaseException)
_State = Union[ConnectionState, '_WebhookState']
MISSING = utils.MISSING
MISSING: Any = utils.MISSING
class AsyncDeferredLock:
@ -83,14 +97,19 @@ class AsyncDeferredLock:
self.lock = lock
self.delta: Optional[float] = None
async def __aenter__(self):
async def __aenter__(self) -> Self:
await self.lock.acquire()
return self
def delay_by(self, delta: float) -> None:
self.delta = delta
async def __aexit__(self, type, value, traceback):
async def __aexit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta:
await asyncio.sleep(self.delta)
self.lock.release()
@ -545,11 +564,11 @@ class PartialWebhookChannel(Hashable):
__slots__ = ('id', 'name')
def __init__(self, *, data):
self.id = int(data['id'])
self.name = data['name']
def __init__(self, *, data: PartialChannelPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
def __repr__(self):
def __repr__(self) -> str:
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
@ -570,13 +589,13 @@ class PartialWebhookGuild(Hashable):
__slots__ = ('id', 'name', '_icon', '_state')
def __init__(self, *, data, state):
self._state = state
self.id = int(data['id'])
self.name = data['name']
self._icon = data['icon']
def __init__(self, *, data: SourceGuildPayload, state: _State) -> None:
self._state: _State = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: str = data['icon']
def __repr__(self):
def __repr__(self) -> str:
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
@property
@ -590,14 +609,14 @@ class PartialWebhookGuild(Hashable):
class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
raise AttributeError('PartialWebhookState does not support http methods.')
class _WebhookState:
__slots__ = ('_parent', '_webhook')
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
def __init__(self, webhook: Any, parent: Optional[_State]):
self._webhook: Any = webhook
self._parent: Optional[ConnectionState]
@ -606,23 +625,23 @@ class _WebhookState:
else:
self._parent = parent
def _get_guild(self, guild_id):
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
if self._parent is not None:
return self._parent._get_guild(guild_id)
return None
def store_user(self, data):
def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
if self._parent is not None:
return self._parent.store_user(data)
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
def create_user(self, data):
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
@property
def http(self):
def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]:
if self._parent is not None:
return self._parent.http
@ -630,7 +649,7 @@ class _WebhookState:
# however, using it should result in a late-binding error.
return _FriendlyHttpAttributeErrorHelper()
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if self._parent is not None:
return getattr(self._parent, attr)
@ -830,19 +849,24 @@ class BaseWebhook(Hashable):
'_state',
)
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
def __init__(
self,
data: WebhookPayload,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
self.auth_token: Optional[str] = token
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
self._state: _State = state or _WebhookState(self, parent=state)
self._update(data)
def _update(self, data: WebhookPayload):
self.id = int(data['id'])
self.type = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name')
self._avatar = data.get('avatar')
self.token = data.get('token')
def _update(self, data: WebhookPayload) -> None:
self.id: int = int(data['id'])
self.type: WebhookType = try_enum(WebhookType, int(data['type']))
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.name: Optional[str] = data.get('name')
self._avatar: Optional[str] = data.get('avatar')
self.token: Optional[str] = data.get('token')
user = data.get('user')
self.user: Optional[Union[BaseUser, User]] = None
@ -1010,11 +1034,17 @@ class Webhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: aiohttp.ClientSession,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
super().__init__(data, token, state)
self.session = session
self.session: aiohttp.ClientSession = session
def __repr__(self):
def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>'
@property
@ -1023,7 +1053,7 @@ class Webhook(BaseWebhook):
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook`.
Parameters
@ -1059,7 +1089,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook` from a webhook URL.
.. versionchanged:: 2.0
@ -1102,7 +1132,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) # type: ignore
@classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook:
def _as_follower(cls, data, *, channel, user) -> Self:
name = f"{channel.guild} #{channel}"
feed: WebhookPayload = {
'id': data['webhook_id'],
@ -1118,8 +1148,8 @@ class Webhook(BaseWebhook):
return cls(feed, session=session, state=state, token=state.http.token)
@classmethod
def from_state(cls, data, state) -> Webhook:
session = state.http._HTTPClient__session
def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self:
session = state.http._HTTPClient__session # type: ignore
return cls(data, session=session, state=state, token=state.http.token)
async def fetch(self, *, prefer_auth: bool = True) -> Webhook:
@ -1168,7 +1198,7 @@ class Webhook(BaseWebhook):
return Webhook(data, self.session, token=self.auth_token, state=self._state)
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
"""|coro|
Deletes this Webhook.

62
discord/webhook/sync.py

@ -37,7 +37,7 @@ import time
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload
import weakref
from .. import utils
@ -56,36 +56,50 @@ __all__ = (
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..file import File
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..message import Attachment
from ..abc import Snowflake
from ..state import ConnectionState
from ..types.webhook import (
Webhook as WebhookPayload,
)
from ..abc import Snowflake
from ..types.message import (
Message as MessagePayload,
)
BE = TypeVar('BE', bound=BaseException)
try:
from requests import Session, Response
except ModuleNotFoundError:
pass
MISSING = utils.MISSING
MISSING: Any = utils.MISSING
class DeferredLock:
def __init__(self, lock: threading.Lock):
self.lock = lock
def __init__(self, lock: threading.Lock) -> None:
self.lock: threading.Lock = lock
self.delta: Optional[float] = None
def __enter__(self):
def __enter__(self) -> Self:
self.lock.acquire()
return self
def delay_by(self, delta: float) -> None:
self.delta = delta
def __exit__(self, type, value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta:
time.sleep(self.delta)
self.lock.release()
@ -218,7 +232,7 @@ class WebhookAdapter:
token: Optional[str] = None,
session: Session,
reason: Optional[str] = None,
):
) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
@ -229,7 +243,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
@ -241,7 +255,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
@ -253,7 +267,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
@ -268,7 +282,7 @@ class WebhookAdapter:
files: Optional[List[File]] = None,
thread_id: Optional[int] = None,
wait: bool = False,
):
) -> MessagePayload:
params = {'wait': int(wait)}
if thread_id:
params['thread_id'] = thread_id
@ -282,7 +296,7 @@ class WebhookAdapter:
message_id: int,
*,
session: Session,
):
) -> MessagePayload:
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -302,7 +316,7 @@ class WebhookAdapter:
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
):
) -> MessagePayload:
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -319,7 +333,7 @@ class WebhookAdapter:
message_id: int,
*,
session: Session,
):
) -> None:
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -335,7 +349,7 @@ class WebhookAdapter:
token: str,
*,
session: Session,
):
) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
@ -345,7 +359,7 @@ class WebhookAdapter:
token: str,
*,
session: Session,
):
) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
@ -569,11 +583,17 @@ class SyncWebhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: Session,
token: Optional[str] = None,
state: Optional[Union[ConnectionState, _WebhookState]] = None,
) -> None:
super().__init__(data, token, state)
self.session = session
self.session: Session = session
def __repr__(self):
def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>'
@property
@ -812,7 +832,7 @@ class SyncWebhook(BaseWebhook):
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
def _create_message(self, data: MessagePayload) -> SyncWebhookMessage:
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore

2
discord/widget.py

@ -278,7 +278,7 @@ class Widget:
def __str__(self) -> str:
return self.json_url
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, Widget):
return self.id == other.id
return False

Loading…
Cancel
Save