From 5aa696ccfa12c2c66c2bd9813b18d2c0cf32ef6a Mon Sep 17 00:00:00 2001 From: Stocker <44980366+StockerMC@users.noreply.github.com> Date: Sun, 13 Mar 2022 23:52:10 -0400 Subject: [PATCH] Fix typing issues and improve typing completeness across the library Co-authored-by: Danny Co-authored-by: Josh --- discord/__main__.py | 21 +- discord/abc.py | 56 ++--- discord/activity.py | 42 ++-- discord/app_commands/commands.py | 22 +- discord/app_commands/errors.py | 8 +- discord/app_commands/models.py | 87 ++++---- discord/app_commands/transformers.py | 4 +- discord/app_commands/tree.py | 53 +++-- discord/asset.py | 53 ++--- discord/audit_logs.py | 25 ++- discord/channel.py | 22 +- discord/client.py | 10 +- discord/colour.py | 13 +- discord/components.py | 10 +- discord/context_managers.py | 10 +- discord/embeds.py | 12 +- discord/emoji.py | 4 +- discord/enums.py | 44 ++-- discord/ext/commands/_types.py | 18 +- discord/ext/commands/bot.py | 88 +++++--- discord/ext/commands/cog.py | 16 +- discord/ext/commands/context.py | 15 +- discord/ext/commands/converter.py | 86 ++++---- discord/ext/commands/cooldowns.py | 2 +- discord/ext/commands/core.py | 115 ++++++----- discord/ext/commands/errors.py | 14 +- discord/ext/commands/flags.py | 18 +- discord/ext/commands/help.py | 293 ++++++++++++++++----------- discord/ext/commands/view.py | 37 ++-- discord/ext/tasks/__init__.py | 8 +- discord/file.py | 2 +- discord/flags.py | 20 +- discord/gateway.py | 30 +-- discord/guild.py | 5 +- discord/http.py | 22 +- discord/integrations.py | 17 +- discord/interactions.py | 7 +- discord/invite.py | 4 +- discord/member.py | 13 +- discord/mentions.py | 8 +- discord/message.py | 17 +- discord/opus.py | 5 +- discord/partial_emoji.py | 15 +- discord/permissions.py | 4 +- discord/player.py | 21 +- discord/reaction.py | 4 +- discord/role.py | 4 +- discord/scheduled_event.py | 6 +- discord/shard.py | 5 +- discord/stage_instance.py | 6 +- discord/state.py | 58 ++++-- discord/sticker.py | 2 +- discord/threads.py | 71 +++---- discord/types/activity.py | 1 + discord/types/widget.py | 5 + discord/ui/button.py | 17 +- discord/ui/item.py | 4 +- discord/ui/modal.py | 8 +- discord/ui/select.py | 18 +- discord/ui/view.py | 30 +-- discord/user.py | 6 +- discord/utils.py | 36 ++-- discord/voice_client.py | 10 +- discord/webhook/async_.py | 122 ++++++----- discord/webhook/sync.py | 62 ++++-- discord/widget.py | 2 +- 66 files changed, 1071 insertions(+), 802 deletions(-) diff --git a/discord/__main__.py b/discord/__main__.py index 570dcbb86..acf9601ed 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -23,7 +23,8 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Dict, Optional + +from typing import Optional, Tuple, Dict import argparse import sys @@ -35,7 +36,7 @@ import aiohttp import platform -def show_version(): +def show_version() -> None: entries = [] entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) @@ -52,7 +53,7 @@ def show_version(): print('\n'.join(entries)) -def core(parser, args): +def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: if args.version: show_version() @@ -185,7 +186,7 @@ _base_table.update((chr(i), None) for i in range(32)) _translation_table = str.maketrans(_base_table) -def to_path(parser, name, *, replace_spaces=False): +def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path: if isinstance(name, Path): return name @@ -223,7 +224,7 @@ def to_path(parser, name, *, replace_spaces=False): return Path(name) -def newbot(parser, args): +def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: new_directory = to_path(parser, args.directory) / to_path(parser, args.name) # as a note exist_ok for Path is a 3.5+ only feature @@ -265,7 +266,7 @@ def newbot(parser, args): print('successfully made bot at', new_directory) -def newcog(parser, args): +def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: cog_dir = to_path(parser, args.directory) try: cog_dir.mkdir(exist_ok=True) @@ -299,7 +300,7 @@ def newcog(parser, args): print('successfully made cog at', directory) -def add_newbot_args(subparser): +def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None: parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser.set_defaults(func=newbot) @@ -310,7 +311,7 @@ def add_newbot_args(subparser): parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') -def add_newcog_args(subparser): +def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None: parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser.set_defaults(func=newcog) @@ -322,7 +323,7 @@ def add_newcog_args(subparser): parser.add_argument('--full', help='add all special methods as well', action='store_true') -def parse_args(): +def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]: parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser.add_argument('-v', '--version', action='store_true', help='shows the library version') parser.set_defaults(func=core) @@ -333,7 +334,7 @@ def parse_args(): return parser, parser.parse_args() -def main(): +def main() -> None: parser, args = parse_args() args.func(parser, args) diff --git a/discord/abc.py b/discord/abc.py index 33b852e74..92a052103 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -91,6 +91,9 @@ if TYPE_CHECKING: GuildChannel as GuildChannelPayload, OverwriteType, ) + from .types.snowflake import ( + SnowflakeList, + ) PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable] MessageableChannel = Union[PartialMessageableChannel, GroupChannel] @@ -708,7 +711,14 @@ class GuildChannel: ) -> None: ... - async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): + async def set_permissions( + self, + target: Union[Member, Role], + *, + overwrite: Any = _undefined, + reason: Optional[str] = None, + **permissions: bool, + ) -> None: r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -917,7 +927,7 @@ class GuildChannel: ) -> None: ... - async def move(self, **kwargs) -> None: + async def move(self, **kwargs: Any) -> None: """|coro| A rich interface to help move a channel relative to other channels. @@ -1248,22 +1258,22 @@ class Messageable: async def send( self, - content=None, + content: Optional[str] = None, *, - tts=False, - embed=None, - embeds=None, - file=None, - files=None, - stickers=None, - delete_after=None, - nonce=None, - allowed_mentions=None, - reference=None, - mention_author=None, - view=None, - suppress_embeds=False, - ): + tts: bool = False, + embed: Optional[Embed] = None, + embeds: Optional[List[Embed]] = None, + file: Optional[File] = None, + files: Optional[List[File]] = None, + stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None, + delete_after: Optional[float] = None, + nonce: Optional[Union[str, int]] = None, + allowed_mentions: Optional[AllowedMentions] = None, + reference: Optional[Union[Message, MessageReference, PartialMessage]] = None, + mention_author: Optional[bool] = None, + view: Optional[View] = None, + suppress_embeds: bool = False, + ) -> Message: """|coro| Sends a message to the destination with the content given. @@ -1368,17 +1378,17 @@ class Messageable: previous_allowed_mention = state.allowed_mentions if stickers is not None: - stickers = [sticker.id for sticker in stickers] + sticker_ids: SnowflakeList = [sticker.id for sticker in stickers] else: - stickers = MISSING + sticker_ids = MISSING if reference is not None: try: - reference = reference.to_message_reference_dict() + reference_dict = reference.to_message_reference_dict() except AttributeError: raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None else: - reference = MISSING + reference_dict = MISSING if view and not hasattr(view, '__discord_ui_view__'): raise TypeError(f'view parameter must be View not {view.__class__!r}') @@ -1399,10 +1409,10 @@ class Messageable: embeds=embeds if embeds is not None else MISSING, nonce=nonce, allowed_mentions=allowed_mentions, - message_reference=reference, + message_reference=reference_dict, previous_allowed_mentions=previous_allowed_mention, mention_author=mention_author, - stickers=stickers, + stickers=sticker_ids, view=view, flags=flags, ) as params: diff --git a/discord/activity.py b/discord/activity.py index 78362a84a..5f1318276 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -123,7 +123,7 @@ class BaseActivity: __slots__ = ('_created_at',) - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: self._created_at: Optional[float] = kwargs.pop('created_at', None) @property @@ -218,7 +218,7 @@ class Activity(BaseActivity): 'buttons', ) - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.state: Optional[str] = kwargs.pop('state', None) self.details: Optional[str] = kwargs.pop('details', None) @@ -363,7 +363,7 @@ class Game(BaseActivity): __slots__ = ('name', '_end', '_start') - def __init__(self, name: str, **extra): + def __init__(self, name: str, **extra: Any) -> None: super().__init__(**extra) self.name: str = name @@ -420,10 +420,10 @@ class Game(BaseActivity): } # fmt: on - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, Game) and other.name == self.name - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -477,7 +477,7 @@ class Streaming(BaseActivity): __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') - def __init__(self, *, name: Optional[str], url: str, **extra: Any): + def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None: super().__init__(**extra) self.platform: Optional[str] = name self.name: Optional[str] = extra.pop('details', name) @@ -501,7 +501,7 @@ class Streaming(BaseActivity): return f'' @property - def twitch_name(self): + def twitch_name(self) -> Optional[str]: """Optional[:class:`str`]: If provided, the twitch name of the user streaming. This corresponds to the ``large_image`` key of the :attr:`Streaming.assets` @@ -528,10 +528,10 @@ class Streaming(BaseActivity): ret['details'] = self.details return ret - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, Streaming) and other.name == self.name and other.url == self.url - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -563,14 +563,14 @@ class Spotify: __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') - def __init__(self, **data): + def __init__(self, **data: Any) -> None: self._state: str = data.pop('state', '') self._details: str = data.pop('details', '') - self._timestamps: Dict[str, int] = data.pop('timestamps', {}) + self._timestamps: ActivityTimestamps = data.pop('timestamps', {}) self._assets: ActivityAssets = data.pop('assets', {}) self._party: ActivityParty = data.pop('party', {}) - self._sync_id: str = data.pop('sync_id') - self._session_id: str = data.pop('session_id') + self._sync_id: str = data.pop('sync_id', '') + self._session_id: Optional[str] = data.pop('session_id') self._created_at: Optional[float] = data.pop('created_at', None) @property @@ -622,7 +622,7 @@ class Spotify: """:class:`str`: The activity's name. This will always return "Spotify".""" return 'Spotify' - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, Spotify) and other._session_id == self._session_id @@ -630,7 +630,7 @@ class Spotify: and other.start == self.start ) - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -691,12 +691,14 @@ class Spotify: @property def start(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user started playing this song in UTC.""" - return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) + # the start key will be present here + return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore @property def end(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" - return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) + # the end key will be present here + return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore @property def duration(self) -> datetime.timedelta: @@ -742,7 +744,7 @@ class CustomActivity(BaseActivity): __slots__ = ('name', 'emoji', 'state') - def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): + def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any) -> None: super().__init__(**extra) self.name: Optional[str] = name self.state: Optional[str] = extra.pop('state', None) @@ -786,10 +788,10 @@ class CustomActivity(BaseActivity): o['emoji'] = self.emoji.to_dict() return o - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 660262c13..d1db76a7e 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -166,7 +166,7 @@ def _validate_auto_complete_callback( return callback -def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType: +def _context_menu_annotation(annotation: Any, *, _none: type = NoneType) -> AppCommandType: if annotation is Message: return AppCommandType.message @@ -686,7 +686,7 @@ class Group: The parent group. ``None`` if there isn't one. """ - __discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = [] + __discord_app_commands_group_children__: ClassVar[List[Union[Command[Any, ..., Any], Group]]] = [] __discord_app_commands_skip_init_binding__: bool = False __discord_app_commands_group_name__: str = MISSING __discord_app_commands_group_description__: str = MISSING @@ -694,10 +694,12 @@ class Group: def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None: if not cls.__discord_app_commands_group_children__: - cls.__discord_app_commands_group_children__ = children = [ + children: List[Union[Command[Any, ..., Any], Group]] = [ member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None ] + cls.__discord_app_commands_group_children__ = children + found = set() for child in children: if child.name in found: @@ -796,15 +798,15 @@ class Group: """Optional[:class:`Group`]: The parent of this group.""" return self.parent - def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]: + def _get_internal_command(self, name: str) -> Optional[Union[Command[Any, ..., Any], Group]]: return self._children.get(name) @property - def commands(self) -> List[Union[Command, Group]]: + def commands(self) -> List[Union[Command[Any, ..., Any], Group]]: """List[Union[:class:`Command`, :class:`Group`]]: The commands that this group contains.""" return list(self._children.values()) - async def on_error(self, interaction: Interaction, command: Command, error: AppCommandError) -> None: + async def on_error(self, interaction: Interaction, command: Command[Any, ..., Any], error: AppCommandError) -> None: """|coro| A callback that is called when a child's command raises an :exc:`AppCommandError`. @@ -823,7 +825,7 @@ class Group: pass - def add_command(self, command: Union[Command, Group], /, *, override: bool = False): + def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None: """Adds a command or group to this group's internal list of commands. Parameters @@ -855,7 +857,7 @@ class Group: if len(self._children) > 25: raise ValueError('maximum number of child commands exceeded') - def remove_command(self, name: str, /) -> Optional[Union[Command, Group]]: + def remove_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]: """Removes a command or group from the internal list of commands. Parameters @@ -872,7 +874,7 @@ class Group: self._children.pop(name, None) - def get_command(self, name: str, /) -> Optional[Union[Command, Group]]: + def get_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]: """Retrieves a command or group from its name. Parameters @@ -1046,7 +1048,7 @@ def describe(**parameters: str) -> Callable[[T], T]: return decorator -def choices(**parameters: List[Choice]) -> Callable[[T], T]: +def choices(**parameters: List[Choice[ChoiceT]]) -> Callable[[T], T]: r"""Instructs the given parameters by their name to use the given choices for their choices. Example: diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py index e4a379e09..73cda13ab 100644 --- a/discord/app_commands/errors.py +++ b/discord/app_commands/errors.py @@ -79,9 +79,9 @@ class CommandInvokeError(AppCommandError): The command that failed. """ - def __init__(self, command: Union[Command, ContextMenu], e: Exception) -> None: + def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu], e: Exception) -> None: self.original: Exception = e - self.command: Union[Command, ContextMenu] = command + self.command: Union[Command[Any, ..., Any], ContextMenu] = command super().__init__(f'Command {command.name!r} raised an exception: {e.__class__.__name__}: {e}') @@ -191,8 +191,8 @@ class CommandSignatureMismatch(AppCommandError): The command that had the signature mismatch. """ - def __init__(self, command: Union[Command, ContextMenu, Group]): - self.command: Union[Command, ContextMenu, Group] = command + def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu, Group]): + self.command: Union[Command[Any, ..., Any], ContextMenu, Group] = command msg = ( f'The signature for command {command.name!r} is different from the one provided by Discord. ' 'This can happen because either your code is out of date or you have not synced the ' diff --git a/discord/app_commands/models.py b/discord/app_commands/models.py index 8d0a23285..3a1e2bb35 100644 --- a/discord/app_commands/models.py +++ b/discord/app_commands/models.py @@ -58,7 +58,10 @@ if TYPE_CHECKING: PartialChannel, PartialThread, ) - from ..types.threads import ThreadMetadata + from ..types.threads import ( + ThreadMetadata, + ThreadArchiveDuration, + ) from ..state import ConnectionState from ..guild import GuildChannel, Guild from ..channel import TextChannel @@ -117,17 +120,19 @@ class AppCommand(Hashable): '_state', ) - def __init__(self, *, data: ApplicationCommandPayload, state=None): - self._state = state + def __init__(self, *, data: ApplicationCommandPayload, state: Optional[ConnectionState] = None) -> None: + self._state: Optional[ConnectionState] = state self._from_data(data) - def _from_data(self, data: ApplicationCommandPayload): + def _from_data(self, data: ApplicationCommandPayload) -> None: self.id: int = int(data['id']) self.application_id: int = int(data['application_id']) self.name: str = data['name'] self.description: str = data['description'] self.type: AppCommandType = try_enum(AppCommandType, data.get('type', 1)) - self.options = [app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', [])] + self.options: List[Union[Argument, AppCommandGroup]] = [ + app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', []) + ] def to_dict(self) -> ApplicationCommandPayload: return { @@ -262,12 +267,12 @@ class AppCommandChannel(Hashable): data: PartialChannel, guild_id: int, ): - self._state = state - self.guild_id = guild_id - self.id = int(data['id']) - self.type = try_enum(ChannelType, data['type']) - self.name = data['name'] - self.permissions = Permissions(int(data['permissions'])) + self._state: ConnectionState = state + self.guild_id: int = guild_id + self.id: int = int(data['id']) + self.type: ChannelType = try_enum(ChannelType, data['type']) + self.name: str = data['name'] + self.permissions: Permissions = Permissions(int(data['permissions'])) def __str__(self) -> str: return self.name @@ -405,13 +410,13 @@ class AppCommandThread(Hashable): data: PartialThread, guild_id: int, ): - self._state = state - self.guild_id = guild_id - self.id = int(data['id']) - self.parent_id = int(data['parent_id']) - self.type = try_enum(ChannelType, data['type']) - self.name = data['name'] - self.permissions = Permissions(int(data['permissions'])) + self._state: ConnectionState = state + self.guild_id: int = guild_id + self.id: int = int(data['id']) + self.parent_id: int = int(data['parent_id']) + self.type: ChannelType = try_enum(ChannelType, data['type']) + self.name: str = data['name'] + self.permissions: Permissions = Permissions(int(data['permissions'])) self._unroll_metadata(data['thread_metadata']) def __str__(self) -> str: @@ -425,14 +430,14 @@ class AppCommandThread(Hashable): """Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found.""" return self._state._get_guild(self.guild_id) - def _unroll_metadata(self, data: ThreadMetadata): - self.archived = data['archived'] - self.archiver_id = _get_as_snowflake(data, 'archiver_id') - self.auto_archive_duration = data['auto_archive_duration'] - self.archive_timestamp = parse_time(data['archive_timestamp']) - self.locked = data.get('locked', False) - self.invitable = data.get('invitable', True) - self._created_at = parse_time(data.get('create_timestamp')) + def _unroll_metadata(self, data: ThreadMetadata) -> None: + self.archived: bool = data['archived'] + self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id') + self.auto_archive_duration: ThreadArchiveDuration = data['auto_archive_duration'] + self.archive_timestamp: datetime = parse_time(data['archive_timestamp']) + self.locked: bool = data.get('locked', False) + self.invitable: bool = data.get('invitable', True) + self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp')) @property def parent(self) -> Optional[TextChannel]: @@ -522,20 +527,24 @@ class Argument: '_state', ) - def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None): - self._state = state - self.parent = parent + def __init__( + self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None + ) -> None: + self._state: Optional[ConnectionState] = state + self.parent: ApplicationCommandParent = parent self._from_data(data) def __repr__(self) -> str: return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>' - def _from_data(self, data: ApplicationCommandOption): + def _from_data(self, data: ApplicationCommandOption) -> None: self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type']) self.name: str = data['name'] self.description: str = data['description'] self.required: bool = data.get('required', False) - self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])] + self.choices: List[Choice[Union[int, float, str]]] = [ + Choice(name=d['name'], value=d['value']) for d in data.get('choices', []) + ] def to_dict(self) -> ApplicationCommandOption: return { @@ -582,20 +591,24 @@ class AppCommandGroup: '_state', ) - def __init__(self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state=None): - self.parent = parent - self._state = state + def __init__( + self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None + ) -> None: + self.parent: ApplicationCommandParent = parent + self._state: Optional[ConnectionState] = state self._from_data(data) def __repr__(self) -> str: return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>' - def _from_data(self, data: ApplicationCommandOption): + def _from_data(self, data: ApplicationCommandOption) -> None: self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type']) self.name: str = data['name'] self.description: str = data['description'] self.required: bool = data.get('required', False) - self.choices: List[Choice] = [Choice(name=d['name'], value=d['value']) for d in data.get('choices', [])] + self.choices: List[Choice[Union[int, float, str]]] = [ + Choice(name=d['name'], value=d['value']) for d in data.get('choices', []) + ] self.arguments: List[Argument] = [ Argument(parent=self, state=self._state, data=d) for d in data.get('options', []) @@ -614,7 +627,7 @@ class AppCommandGroup: def app_command_option_factory( - parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state=None + parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state: Optional[ConnectionState] = None ) -> Union[Argument, AppCommandGroup]: if is_app_command_argument_type(data['type']): return Argument(parent=parent, data=data, state=state) diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index 1f74a6b7d..b6e50238d 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -95,7 +95,7 @@ class CommandParameter: description: str = MISSING required: bool = MISSING default: Any = MISSING - choices: List[Choice] = MISSING + choices: List[Choice[Union[str, int, float]]] = MISSING type: AppCommandOptionType = MISSING channel_types: List[ChannelType] = MISSING min_value: Optional[Union[int, float]] = None @@ -549,7 +549,7 @@ ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = { def get_supported_annotation( annotation: Any, *, - _none=NoneType, + _none: type = NoneType, _mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS, ) -> Tuple[Any, Any]: """Returns an appropriate, yet supported, annotation along with an optional default value. diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 371a70e1d..5aa0aef48 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -26,7 +26,22 @@ from __future__ import annotations import inspect import sys import traceback -from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Set, Tuple, TypeVar, Union, overload + +from typing import ( + Any, + TYPE_CHECKING, + Callable, + Dict, + Generic, + List, + Literal, + Optional, + Set, + Tuple, + TypeVar, + Union, + overload, +) from collections import Counter @@ -194,13 +209,13 @@ class CommandTree(Generic[ClientT]): def add_command( self, - command: Union[Command, ContextMenu, Group], + command: Union[Command[Any, ..., Any], ContextMenu, Group], /, *, guild: Optional[Snowflake] = MISSING, guilds: List[Snowflake] = MISSING, override: bool = False, - ): + ) -> None: """Adds an application command to the tree. This only adds the command locally -- in order to sync the commands @@ -317,7 +332,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = ..., type: Literal[AppCommandType.chat_input] = ..., - ) -> Optional[Union[Command, Group]]: + ) -> Optional[Union[Command[Any, ..., Any], Group]]: ... @overload @@ -328,7 +343,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = ..., type: AppCommandType = ..., - ) -> Optional[Union[Command, ContextMenu, Group]]: + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: ... def remove_command( @@ -338,7 +353,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = None, type: AppCommandType = AppCommandType.chat_input, - ) -> Optional[Union[Command, ContextMenu, Group]]: + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: """Removes an application command from the tree. This only removes the command locally -- in order to sync the commands @@ -396,7 +411,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = ..., type: Literal[AppCommandType.chat_input] = ..., - ) -> Optional[Union[Command, Group]]: + ) -> Optional[Union[Command[Any, ..., Any], Group]]: ... @overload @@ -407,7 +422,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = ..., type: AppCommandType = ..., - ) -> Optional[Union[Command, ContextMenu, Group]]: + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: ... def get_command( @@ -417,7 +432,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = None, type: AppCommandType = AppCommandType.chat_input, - ) -> Optional[Union[Command, ContextMenu, Group]]: + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: """Gets a application command from the tree. Parameters @@ -468,7 +483,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = ..., type: Literal[AppCommandType.chat_input] = ..., - ) -> List[Union[Command, Group]]: + ) -> List[Union[Command[Any, ..., Any], Group]]: ... @overload @@ -477,7 +492,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = ..., type: AppCommandType = ..., - ) -> Union[List[Union[Command, Group]], List[ContextMenu]]: + ) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]: ... def get_commands( @@ -485,7 +500,7 @@ class CommandTree(Generic[ClientT]): *, guild: Optional[Snowflake] = None, type: AppCommandType = AppCommandType.chat_input, - ) -> Union[List[Union[Command, Group]], List[ContextMenu]]: + ) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]: """Gets all application commands from the tree. Parameters @@ -518,9 +533,11 @@ class CommandTree(Generic[ClientT]): value = type.value return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value] - def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]: + def _get_all_commands( + self, *, guild: Optional[Snowflake] = None + ) -> List[Union[Command[Any, ..., Any], Group, ContextMenu]]: if guild is None: - base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values()) + base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(self._global_commands.values()) base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None) return base else: @@ -530,7 +547,7 @@ class CommandTree(Generic[ClientT]): guild_id = guild.id return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id] else: - base: List[Union[Command, Group, ContextMenu]] = list(commands.values()) + base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values()) guild_id = guild.id base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id) return base @@ -564,7 +581,7 @@ class CommandTree(Generic[ClientT]): async def on_error( self, interaction: Interaction, - command: Optional[Union[ContextMenu, Command]], + command: Optional[Union[ContextMenu, Command[Any, ..., Any]]], error: AppCommandError, ) -> None: """|coro| @@ -742,7 +759,7 @@ class CommandTree(Generic[ClientT]): self.client.loop.create_task(wrapper(), name='CommandTree-invoker') - async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int): + async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int) -> None: name = data['name'] guild_id = _get_as_snowflake(data, 'guild_id') ctx_menu = self._context_menus.get((name, guild_id, type)) @@ -770,7 +787,7 @@ class CommandTree(Generic[ClientT]): except AppCommandError as e: await self.on_error(interaction, ctx_menu, e) - async def call(self, interaction: Interaction): + async def call(self, interaction: Interaction) -> None: """|coro| Given an :class:`~discord.Interaction`, calls the matching diff --git a/discord/asset.py b/discord/asset.py index 67ac7dceb..f296693e3 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -39,6 +39,13 @@ __all__ = ( # fmt: on if TYPE_CHECKING: + from typing_extensions import Self + + from .state import ConnectionState + from .webhook.async_ import _WebhookState + + _State = Union[ConnectionState, _WebhookState] + ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] @@ -77,7 +84,7 @@ class AssetMixin: return await self._state.http.get_from_cdn(self.url) - async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int: + async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int: """|coro| Saves this asset into a file-like object. @@ -153,14 +160,14 @@ class Asset(AssetMixin): BASE = 'https://cdn.discordapp.com' - def __init__(self, state, *, url: str, key: str, animated: bool = False): - self._state = state - self._url = url - self._animated = animated - self._key = key + def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None: + self._state: _State = state + self._url: str = url + self._animated: bool = animated + self._key: str = key @classmethod - def _from_default_avatar(cls, state, index: int) -> Asset: + def _from_default_avatar(cls, state: _State, index: int) -> Self: return cls( state, url=f'{cls.BASE}/embed/avatars/{index}.png', @@ -169,7 +176,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: + def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self: animated = avatar.startswith('a_') format = 'gif' if animated else 'png' return cls( @@ -180,7 +187,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: + def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self: animated = avatar.startswith('a_') format = 'gif' if animated else 'png' return cls( @@ -191,7 +198,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: + def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self: return cls( state, url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', @@ -200,7 +207,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: + def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self: return cls( state, url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', @@ -209,7 +216,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_scheduled_event_cover_image(cls, state, scheduled_event_id: int, cover_image_hash: str) -> Asset: + def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self: return cls( state, url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024', @@ -218,7 +225,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: + def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self: animated = image.startswith('a_') format = 'gif' if animated else 'png' return cls( @@ -229,7 +236,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: + def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self: animated = icon_hash.startswith('a_') format = 'gif' if animated else 'png' return cls( @@ -240,7 +247,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_sticker_banner(cls, state, banner: int) -> Asset: + def _from_sticker_banner(cls, state: _State, banner: int) -> Self: return cls( state, url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', @@ -249,7 +256,7 @@ class Asset(AssetMixin): ) @classmethod - def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: + def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self: animated = banner_hash.startswith('a_') format = 'gif' if animated else 'png' return cls( @@ -265,14 +272,14 @@ class Asset(AssetMixin): def __len__(self) -> int: return len(self._url) - def __repr__(self): + def __repr__(self) -> str: shorten = self._url.replace(self.BASE, '') return f'' - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, Asset) and self._url == other._url - def __hash__(self): + def __hash__(self) -> int: return hash(self._url) @property @@ -295,7 +302,7 @@ class Asset(AssetMixin): size: int = MISSING, format: ValidAssetFormatTypes = MISSING, static_format: ValidStaticFormatTypes = MISSING, - ) -> Asset: + ) -> Self: """Returns a new asset with the passed components replaced. .. versionchanged:: 2.0 @@ -350,7 +357,7 @@ class Asset(AssetMixin): url = str(url) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) - def with_size(self, size: int, /) -> Asset: + def with_size(self, size: int, /) -> Self: """Returns a new asset with the specified size. .. versionchanged:: 2.0 @@ -378,7 +385,7 @@ class Asset(AssetMixin): url = str(yarl.URL(self._url).with_query(size=size)) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) - def with_format(self, format: ValidAssetFormatTypes, /) -> Asset: + def with_format(self, format: ValidAssetFormatTypes, /) -> Self: """Returns a new asset with the specified format. .. versionchanged:: 2.0 @@ -413,7 +420,7 @@ class Asset(AssetMixin): url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) - def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: + def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self: """Returns a new asset with the specified static format. This only changes the format if the underlying asset is diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 04ccc2ff4..3ebf49efc 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -50,12 +50,12 @@ if TYPE_CHECKING: from .member import Member from .role import Role from .scheduled_event import ScheduledEvent + from .state import ConnectionState from .types.audit_log import ( AuditLogChange as AuditLogChangePayload, AuditLogEntry as AuditLogEntryPayload, ) from .types.channel import ( - PartialChannel as PartialChannelPayload, PermissionOverwrite as PermissionOverwritePayload, ) from .types.invite import Invite as InvitePayload @@ -242,8 +242,8 @@ class AuditLogChanges: # fmt: on def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]): - self.before = AuditLogDiff() - self.after = AuditLogDiff() + self.before: AuditLogDiff = AuditLogDiff() + self.after: AuditLogDiff = AuditLogDiff() for elem in data: attr = elem['key'] @@ -390,17 +390,17 @@ class AuditLogEntry(Hashable): """ def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): - self._state = guild._state - self.guild = guild - self._users = users + self._state: ConnectionState = guild._state + self.guild: Guild = guild + self._users: Dict[int, User] = users self._from_data(data) def _from_data(self, data: AuditLogEntryPayload) -> None: - self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) - self.id = int(data['id']) + self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type']) + self.id: int = int(data['id']) # this key is technically not usually present - self.reason = data.get('reason') + self.reason: Optional[str] = data.get('reason') extra = data.get('options') # fmt: off @@ -464,10 +464,13 @@ class AuditLogEntry(Hashable): self._changes = data.get('changes', []) user_id = utils._get_as_snowflake(data, 'user_id') - self.user = user_id and self._get_member(user_id) + self.user: Optional[Union[User, Member]] = self._get_member(user_id) self._target_id = utils._get_as_snowflake(data, 'target_id') - def _get_member(self, user_id: int) -> Union[Member, User, None]: + def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]: + if user_id is None: + return None + return self.guild.get_member(user_id) or self._users.get(user_id) def __repr__(self) -> str: diff --git a/discord/channel.py b/discord/channel.py index 39fefa904..c2d9ba1f6 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -198,7 +198,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self._fill_overwrites(data) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self @property @@ -283,7 +283,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): async def edit(self) -> Optional[TextChannel]: ... - async def edit(self, *, reason=None, **options): + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]: """|coro| Edits the channel. @@ -908,7 +908,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha return self.guild.id, self.id def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: - self.guild = guild + self.guild: Guild = guild self.name: str = data['name'] self.rtc_region: Optional[str] = data.get('rtc_region') self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) @@ -1076,7 +1076,7 @@ class VoiceChannel(VocalGuildChannel): async def edit(self) -> Optional[VoiceChannel]: ... - async def edit(self, *, reason=None, **options): + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]: """|coro| Edits the channel. @@ -1220,7 +1220,7 @@ class StageChannel(VocalGuildChannel): def _update(self, guild: Guild, data: StageChannelPayload) -> None: super()._update(guild, data) - self.topic = data.get('topic') + self.topic: Optional[str] = data.get('topic') @property def requesting_to_speak(self) -> List[Member]: @@ -1361,7 +1361,7 @@ class StageChannel(VocalGuildChannel): async def edit(self) -> Optional[StageChannel]: ... - async def edit(self, *, reason=None, **options): + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]: """|coro| Edits the channel. @@ -1522,7 +1522,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): async def edit(self) -> Optional[CategoryChannel]: ... - async def edit(self, *, reason=None, **options): + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]: """|coro| Edits the channel. @@ -1578,7 +1578,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore @utils.copy_doc(discord.abc.GuildChannel.move) - async def move(self, **kwargs): + async def move(self, **kwargs: Any) -> None: kwargs.pop('category', None) await super().move(**kwargs) @@ -1772,7 +1772,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): async def edit(self) -> Optional[StoreChannel]: ... - async def edit(self, *, reason=None, **options): + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]: """|coro| Edits the channel. @@ -1874,7 +1874,7 @@ class DMChannel(discord.abc.Messageable, Hashable): self.me: ClientUser = me self.id: int = int(data['id']) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self def __str__(self) -> str: @@ -2026,7 +2026,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): else: self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self def __str__(self) -> str: diff --git a/discord/client.py b/discord/client.py index fc3bea864..6b1835646 100644 --- a/discord/client.py +++ b/discord/client.py @@ -196,11 +196,11 @@ class Client: unsync_clock: bool = options.pop('assume_unsync_clock', True) self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock) - self._handlers: Dict[str, Callable] = { + self._handlers: Dict[str, Callable[..., None]] = { 'ready': self._handle_ready, } - self._hooks: Dict[str, Callable] = { + self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = { 'before_identify': self._call_before_identify_hook, } @@ -698,7 +698,7 @@ class Client: raise TypeError('activity must derive from BaseActivity.') @property - def status(self): + def status(self) -> Status: """:class:`.Status`: The status being used upon logging on to Discord. @@ -709,7 +709,7 @@ class Client: return Status.online @status.setter - def status(self, value): + def status(self, value: Status) -> None: if value is Status.offline: self._connection._status = 'invisible' elif isinstance(value, Status): @@ -1077,7 +1077,7 @@ class Client: *, activity: Optional[BaseActivity] = None, status: Optional[Status] = None, - ): + ) -> None: """|coro| Changes the client's presence. diff --git a/discord/colour.py b/discord/colour.py index f204d08ee..5308cb74f 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -32,7 +32,6 @@ from typing import ( Callable, Optional, Tuple, - Type, Union, ) @@ -90,10 +89,10 @@ class Colour: def _get_byte(self, byte: int) -> int: return (self.value >> (8 * byte)) & 0xFF - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, Colour) and self.value == other.value - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __str__(self) -> str: @@ -265,28 +264,28 @@ class Colour: """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" return cls(0x95A5A6) - lighter_gray: Callable[[Type[Self]], Self] = lighter_grey + lighter_gray = lighter_grey @classmethod def dark_grey(cls) -> Self: """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" return cls(0x607D8B) - dark_gray: Callable[[Type[Self]], Self] = dark_grey + dark_gray = dark_grey @classmethod def light_grey(cls) -> Self: """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" return cls(0x979C9F) - light_gray: Callable[[Type[Self]], Self] = light_grey + light_gray = light_grey @classmethod def darker_grey(cls) -> Self: """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" return cls(0x546E7A) - darker_gray: Callable[[Type[Self]], Self] = darker_grey + darker_gray = darker_grey @classmethod def og_blurple(cls) -> Self: diff --git a/discord/components.py b/discord/components.py index 37559f8b1..86c73ae6b 100644 --- a/discord/components.py +++ b/discord/components.py @@ -310,9 +310,9 @@ class SelectOption: emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, default: bool = False, ) -> None: - self.label = label - self.value = label if value is MISSING else value - self.description = description + self.label: str = label + self.value: str = label if value is MISSING else value + self.description: Optional[str] = description if emoji is not None: if isinstance(emoji, str): @@ -322,8 +322,8 @@ class SelectOption: else: raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') - self.emoji = emoji - self.default = default + self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji + self.default: bool = default def __repr__(self) -> str: return ( diff --git a/discord/context_managers.py b/discord/context_managers.py index e6aa67901..ce8e73d97 100644 --- a/discord/context_managers.py +++ b/discord/context_managers.py @@ -25,13 +25,15 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional, Type, TypeVar if TYPE_CHECKING: from .abc import Messageable from types import TracebackType + BE = TypeVar('BE', bound=BaseException) + # fmt: off __all__ = ( 'Typing', @@ -67,13 +69,13 @@ class Typing: async def __aenter__(self) -> None: self._channel = channel = await self.messageable._get_channel() await channel._state.http.send_typing(channel.id) - self.task: asyncio.Task = self.loop.create_task(self.do_typing()) + self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing()) self.task.add_done_callback(_typing_done_callback) async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], + exc_type: Optional[Type[BE]], + exc: Optional[BE], traceback: Optional[TracebackType], ) -> None: self.task.cancel() diff --git a/discord/embeds.py b/discord/embeds.py index f699da4b3..cb710d6e4 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -189,10 +189,10 @@ class Embed: ): self.colour = colour if colour is not EmptyEmbed else color - self.title = title - self.type = type - self.url = url - self.description = description + self.title: MaybeEmpty[str] = title + self.type: EmbedType = type + self.url: MaybeEmpty[str] = url + self.description: MaybeEmpty[str] = description if self.title is not EmptyEmbed: self.title = str(self.title) @@ -311,7 +311,7 @@ class Embed: return getattr(self, '_colour', EmptyEmbed) @colour.setter - def colour(self, value: Union[int, Colour, _EmptyEmbed]): + def colour(self, value: Union[int, Colour, _EmptyEmbed]) -> None: if isinstance(value, (Colour, _EmptyEmbed)): self._colour = value elif isinstance(value, int): @@ -326,7 +326,7 @@ class Embed: return getattr(self, '_timestamp', EmptyEmbed) @timestamp.setter - def timestamp(self, value: MaybeEmpty[datetime.datetime]): + def timestamp(self, value: MaybeEmpty[datetime.datetime]) -> None: if isinstance(value, datetime.datetime): if value.tzinfo is None: value = value.astimezone() diff --git a/discord/emoji.py b/discord/emoji.py index 35ceef532..9d2554593 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -142,10 +142,10 @@ class Emoji(_EmojiTag, AssetMixin): def __repr__(self) -> str: return f'' - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, _EmojiTag) and self.id == other.id - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: diff --git a/discord/enums.py b/discord/enums.py index 25882cc1a..070901cac 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -25,7 +25,7 @@ from __future__ import annotations import types from collections import namedtuple -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar +from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping __all__ = ( 'Enum', @@ -131,38 +131,38 @@ class EnumMeta(type): value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood return actual_cls - def __iter__(cls): + def __iter__(cls) -> Iterator[Any]: return (cls._enum_member_map_[name] for name in cls._enum_member_names_) - def __reversed__(cls): + def __reversed__(cls) -> Iterator[Any]: return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) - def __len__(cls): + def __len__(cls) -> int: return len(cls._enum_member_names_) - def __repr__(cls): + def __repr__(cls) -> str: return f'' @property - def __members__(cls): + def __members__(cls) -> Mapping[str, Any]: return types.MappingProxyType(cls._enum_member_map_) - def __call__(cls, value): + def __call__(cls, value: str) -> Any: try: return cls._enum_value_map_[value] except (KeyError, TypeError): raise ValueError(f"{value!r} is not a valid {cls.__name__}") - def __getitem__(cls, key): + def __getitem__(cls, key: str) -> Any: return cls._enum_member_map_[key] - def __setattr__(cls, name, value): + def __setattr__(cls, name: str, value: Any) -> None: raise TypeError('Enums are immutable.') - def __delattr__(cls, attr): + def __delattr__(cls, attr: str) -> None: raise TypeError('Enums are immutable') - def __instancecheck__(self, instance): + def __instancecheck__(self, instance: Any) -> bool: # isinstance(x, Y) # -> __instancecheck__(Y, x) try: @@ -197,7 +197,7 @@ class ChannelType(Enum): private_thread = 12 stage_voice = 13 - def __str__(self): + def __str__(self) -> str: return self.name @@ -233,10 +233,10 @@ class SpeakingState(Enum): soundshare = 2 priority = 4 - def __str__(self): + def __str__(self) -> str: return self.name - def __int__(self): + def __int__(self) -> int: return self.value @@ -247,7 +247,7 @@ class VerificationLevel(Enum, comparable=True): high = 3 highest = 4 - def __str__(self): + def __str__(self) -> str: return self.name @@ -256,7 +256,7 @@ class ContentFilter(Enum, comparable=True): no_role = 1 all_members = 2 - def __str__(self): + def __str__(self) -> str: return self.name @@ -268,7 +268,7 @@ class Status(Enum): do_not_disturb = 'dnd' invisible = 'invisible' - def __str__(self): + def __str__(self) -> str: return self.value @@ -280,7 +280,7 @@ class DefaultAvatar(Enum): orange = 3 red = 4 - def __str__(self): + def __str__(self) -> str: return self.name @@ -467,7 +467,7 @@ class ActivityType(Enum): custom = 4 competing = 5 - def __int__(self): + def __int__(self) -> int: return self.value @@ -542,7 +542,7 @@ class VideoQualityMode(Enum): auto = 1 full = 2 - def __int__(self): + def __int__(self) -> int: return self.value @@ -552,7 +552,7 @@ class ComponentType(Enum): select = 3 text_input = 4 - def __int__(self): + def __int__(self) -> int: return self.value @@ -571,7 +571,7 @@ class ButtonStyle(Enum): red = 4 url = 5 - def __int__(self): + def __int__(self) -> int: return self.value diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 5f7192f0a..2907d8f69 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -23,21 +23,35 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple +T = TypeVar('T') + if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from .bot import Bot, AutoShardedBot from .context import Context from .cog import Cog from .errors import CommandError -T = TypeVar('T') + P = ParamSpec('P') + MaybeCoroFunc = Union[ + Callable[P, 'Coro[T]'], + Callable[P, T], + ] +else: + P = TypeVar('P') + MaybeCoroFunc = Tuple[P, T] Coro = Coroutine[Any, Any, T] MaybeCoro = Union[T, Coro[T]] CoroFunc = Callable[..., Coro[Any]] ContextT = TypeVar('ContextT', bound='Context') +_Bot = Union['Bot', 'AutoShardedBot'] +BotT = TypeVar('BotT', bound=_Bot) Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]] Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 1ee2709ce..f74ce3f5b 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -33,7 +33,21 @@ import importlib.util import sys import traceback import types -from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload +from typing import ( + Any, + Callable, + Mapping, + List, + Dict, + TYPE_CHECKING, + Optional, + TypeVar, + Type, + Union, + Iterable, + Collection, + overload, +) import discord from discord import app_commands @@ -55,10 +69,18 @@ if TYPE_CHECKING: from discord.message import Message from discord.abc import User, Snowflake from ._types import ( + _Bot, + BotT, Check, CoroFunc, + ContextT, + MaybeCoroFunc, ) + _Prefix = Union[Iterable[str], str] + _PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix] + PrefixType = Union[_Prefix, _PrefixCallable[BotT]] + __all__ = ( 'when_mentioned', 'when_mentioned_or', @@ -68,11 +90,9 @@ __all__ = ( T = TypeVar('T') CFT = TypeVar('CFT', bound='CoroFunc') -CXT = TypeVar('CXT', bound='Context') -BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]') -def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: +def when_mentioned(bot: _Bot, msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. @@ -81,7 +101,7 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore -def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: +def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]: """A callable that implements when mentioned or other prefixes provided. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. @@ -124,27 +144,33 @@ class _DefaultRepr: return '' -_default = _DefaultRepr() +_default: Any = _DefaultRepr() -class BotBase(GroupMixin): - def __init__(self, command_prefix, help_command=_default, description=None, **options): +class BotBase(GroupMixin[None]): + def __init__( + self, + command_prefix: PrefixType[BotT], + help_command: HelpCommand = _default, + description: Optional[str] = None, + **options: Any, + ) -> None: super().__init__(**options) - self.command_prefix = command_prefix + self.command_prefix: PrefixType[BotT] = command_prefix self.extra_events: Dict[str, List[CoroFunc]] = {} # Self doesn't have the ClientT bound, but since this is a mixin it technically does self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} self._checks: List[Check] = [] - self._check_once = [] - self._before_invoke = None - self._after_invoke = None - self._help_command = None - self.description = inspect.cleandoc(description) if description else '' - self.owner_id = options.get('owner_id') - self.owner_ids = options.get('owner_ids', set()) - self.strip_after_prefix = options.get('strip_after_prefix', False) + self._check_once: List[Check] = [] + self._before_invoke: Optional[CoroFunc] = None + self._after_invoke: Optional[CoroFunc] = None + self._help_command: Optional[HelpCommand] = None + self.description: str = inspect.cleandoc(description) if description else '' + self.owner_id: Optional[int] = options.get('owner_id') + self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set()) + self.strip_after_prefix: bool = options.get('strip_after_prefix', False) if self.owner_id and self.owner_ids: raise TypeError('Both owner_id and owner_ids are set.') @@ -182,7 +208,7 @@ class BotBase(GroupMixin): await super().close() # type: ignore - async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: + async def on_command_error(self, context: Context[BotT], exception: errors.CommandError) -> None: """|coro| The default command error handler provided by the bot. @@ -237,7 +263,7 @@ class BotBase(GroupMixin): self.add_check(func) # type: ignore return func - def add_check(self, func: Check, /, *, call_once: bool = False) -> None: + def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: """Adds a global check to the bot. This is the non-decorator interface to :meth:`.check` @@ -261,7 +287,7 @@ class BotBase(GroupMixin): else: self._checks.append(func) - def remove_check(self, func: Check, /, *, call_once: bool = False) -> None: + def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: """Removes a global check from the bot. This function is idempotent and will not raise an exception @@ -324,7 +350,7 @@ class BotBase(GroupMixin): self.add_check(func, call_once=True) return func - async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: + async def can_run(self, ctx: Context[BotT], *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks if len(data) == 0: @@ -947,7 +973,7 @@ class BotBase(GroupMixin): # if the load failed, the remnants should have been # cleaned from the load_extension function call # so let's load it from our old compiled library. - await lib.setup(self) # type: ignore + await lib.setup(self) self.__extensions[name] = lib # revert sys.modules back to normal and raise back to caller @@ -1015,11 +1041,12 @@ class BotBase(GroupMixin): """ prefix = ret = self.command_prefix if callable(prefix): - ret = await discord.utils.maybe_coroutine(prefix, self, message) + # self will be a Bot or AutoShardedBot + ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore if not isinstance(ret, str): try: - ret = list(ret) + ret = list(ret) # type: ignore except TypeError: # It's possible that a generator raised this exception. Don't # replace it with our own error if that's the case. @@ -1048,15 +1075,15 @@ class BotBase(GroupMixin): self, message: Message, *, - cls: Type[CXT] = ..., - ) -> CXT: # type: ignore + cls: Type[ContextT] = ..., + ) -> ContextT: ... async def get_context( self, message: Message, *, - cls: Type[CXT] = MISSING, + cls: Type[ContextT] = MISSING, ) -> Any: r"""|coro| @@ -1137,7 +1164,7 @@ class BotBase(GroupMixin): ctx.command = self.all_commands.get(invoker) return ctx - async def invoke(self, ctx: Context) -> None: + async def invoke(self, ctx: Context[BotT]) -> None: """|coro| Invokes the command given under the invocation context and @@ -1189,9 +1216,10 @@ class BotBase(GroupMixin): return ctx = await self.get_context(message) - await self.invoke(ctx) + # the type of the invocation context's bot attribute will be correct + await self.invoke(ctx) # type: ignore - async def on_message(self, message): + async def on_message(self, message: Message) -> None: await self.process_commands(message) diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index d16584a39..59eb54c4e 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -30,7 +30,7 @@ from discord.utils import maybe_coroutine from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union -from ._types import _BaseCommand +from ._types import _BaseCommand, BotT if TYPE_CHECKING: from typing_extensions import Self @@ -112,7 +112,7 @@ class CogMeta(type): __cog_name__: str __cog_settings__: Dict[str, Any] - __cog_commands__: List[Command] + __cog_commands__: List[Command[Any, ..., Any]] __cog_is_app_commands_group__: bool __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]] __cog_listeners__: List[Tuple[str, str]] @@ -406,7 +406,7 @@ class Cog(metaclass=CogMeta): pass @_cog_special_method - def bot_check_once(self, ctx: Context) -> bool: + def bot_check_once(self, ctx: Context[BotT]) -> bool: """A special method that registers as a :meth:`.Bot.check_once` check. @@ -416,7 +416,7 @@ class Cog(metaclass=CogMeta): return True @_cog_special_method - def bot_check(self, ctx: Context) -> bool: + def bot_check(self, ctx: Context[BotT]) -> bool: """A special method that registers as a :meth:`.Bot.check` check. @@ -426,7 +426,7 @@ class Cog(metaclass=CogMeta): return True @_cog_special_method - def cog_check(self, ctx: Context) -> bool: + def cog_check(self, ctx: Context[BotT]) -> bool: """A special method that registers as a :func:`~discord.ext.commands.check` for every command and subcommand in this cog. @@ -436,7 +436,7 @@ class Cog(metaclass=CogMeta): return True @_cog_special_method - async def cog_command_error(self, ctx: Context, error: Exception) -> None: + async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. @@ -455,7 +455,7 @@ class Cog(metaclass=CogMeta): pass @_cog_special_method - async def cog_before_invoke(self, ctx: Context) -> None: + async def cog_before_invoke(self, ctx: Context[BotT]) -> None: """A special method that acts as a cog local pre-invoke hook. This is similar to :meth:`.Command.before_invoke`. @@ -470,7 +470,7 @@ class Cog(metaclass=CogMeta): pass @_cog_special_method - async def cog_after_invoke(self, ctx: Context) -> None: + async def cog_after_invoke(self, ctx: Context[BotT]) -> None: """A special method that acts as a cog local post-invoke hook. This is similar to :meth:`.Command.after_invoke`. diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 8e5d59851..8acfa09e5 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -28,6 +28,8 @@ import re from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union +from ._types import BotT + import discord.abc import discord.utils @@ -59,7 +61,6 @@ MISSING: Any = discord.utils.MISSING T = TypeVar('T') -BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") CogT = TypeVar('CogT', bound="Cog") if TYPE_CHECKING: @@ -133,10 +134,10 @@ class Context(discord.abc.Messageable, Generic[BotT]): args: List[Any] = MISSING, kwargs: Dict[str, Any] = MISSING, prefix: Optional[str] = None, - command: Optional[Command] = None, + command: Optional[Command[Any, ..., Any]] = None, invoked_with: Optional[str] = None, invoked_parents: List[str] = MISSING, - invoked_subcommand: Optional[Command] = None, + invoked_subcommand: Optional[Command[Any, ..., Any]] = None, subcommand_passed: Optional[str] = None, command_failed: bool = False, current_parameter: Optional[inspect.Parameter] = None, @@ -146,11 +147,11 @@ class Context(discord.abc.Messageable, Generic[BotT]): self.args: List[Any] = args or [] self.kwargs: Dict[str, Any] = kwargs or {} self.prefix: Optional[str] = prefix - self.command: Optional[Command] = command + self.command: Optional[Command[Any, ..., Any]] = command self.view: StringView = view self.invoked_with: Optional[str] = invoked_with self.invoked_parents: List[str] = invoked_parents or [] - self.invoked_subcommand: Optional[Command] = invoked_subcommand + self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand self.subcommand_passed: Optional[str] = subcommand_passed self.command_failed: bool = command_failed self.current_parameter: Optional[inspect.Parameter] = current_parameter @@ -361,7 +362,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): return None cmd = cmd.copy() - cmd.context = self + cmd.context = self # type: ignore if len(args) == 0: await cmd.prepare_help_command(self, None) mapping = cmd.get_bot_mapping() @@ -390,7 +391,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): try: if hasattr(entity, '__cog_commands__'): injected = wrap_callback(cmd.send_cog_help) - return await injected(entity) + return await injected(entity) # type: ignore elif isinstance(entity, Group): injected = wrap_callback(cmd.send_group_help) return await injected(entity) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index fce4a3021..e9a2ea758 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -41,7 +41,6 @@ from typing import ( Tuple, Union, runtime_checkable, - overload, ) import discord @@ -51,9 +50,8 @@ if TYPE_CHECKING: from .context import Context from discord.state import Channel from discord.threads import Thread - from .bot import Bot, AutoShardedBot - _Bot = TypeVar('_Bot', bound=Union[Bot, AutoShardedBot]) + from ._types import BotT, _Bot __all__ = ( @@ -87,7 +85,7 @@ __all__ = ( ) -def _get_from_guilds(bot, getter, argument): +def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any: result = None for guild in bot.guilds: result = getattr(guild, getter)(argument) @@ -115,7 +113,7 @@ class Converter(Protocol[T_co]): method to do its conversion logic. This method must be a :ref:`coroutine `. """ - async def convert(self, ctx: Context, argument: str) -> T_co: + async def convert(self, ctx: Context[BotT], argument: str) -> T_co: """|coro| The method to override to do conversion logic. @@ -163,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]): 2. Lookup by member, role, or channel mention. """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object: match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) if match is None: @@ -196,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]): optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled. """ - async def query_member_named(self, guild, argument): + async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]: cache = guild._state.member_cache_flags.joined if len(argument) > 5 and argument[-5] == '#': username, _, discriminator = argument.rpartition('#') @@ -206,7 +204,7 @@ class MemberConverter(IDConverter[discord.Member]): members = await guild.query_members(argument, limit=100, cache=cache) return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members) - async def query_member_by_id(self, bot, guild, user_id): + async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]: ws = bot._get_websocket(shard_id=guild.shard_id) cache = guild._state.member_cache_flags.joined if ws.is_ratelimited(): @@ -227,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]): return None return members[0] - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member: bot = ctx.bot match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) guild = ctx.guild @@ -281,7 +279,7 @@ class UserConverter(IDConverter[discord.User]): and it's not available in cache. """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.User: match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) result = None state = ctx._state @@ -359,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _resolve_channel( - ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int] + ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int] ) -> Optional[Union[Channel, Thread]]: if channel_id is None: # we were passed just a message id so we can assume the channel is the current context channel @@ -373,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): return ctx.bot.get_channel(channel_id) - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage: guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) channel = self._resolve_channel(ctx, guild_id, channel_id) if not channel or not isinstance(channel, discord.abc.Messageable): @@ -396,7 +394,7 @@ class MessageConverter(IDConverter[discord.Message]): Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message: guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) message = ctx.bot._connection._get_message(message_id) if message: @@ -427,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): .. versionadded:: 2.0 """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel: return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) @staticmethod - def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: + def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT: bot = ctx.bot match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) @@ -448,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): def check(c): return isinstance(c, type) and c.name == argument - result = discord.utils.find(check, bot.get_all_channels()) + result = discord.utils.find(check, bot.get_all_channels()) # type: ignore else: channel_id = int(match.group(1)) if guild: @@ -463,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): return result @staticmethod - def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: + def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT: bot = ctx.bot match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) @@ -502,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) @@ -522,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) @@ -541,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): 3. Lookup by name """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) @@ -561,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) @@ -580,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): .. versionadded:: 1.7 """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) @@ -598,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]): .. versionadded: 2.0 """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread: return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) @@ -630,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]): RGB_REGEX = re.compile(r'rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)') - def parse_hex_number(self, argument): + def parse_hex_number(self, argument: str) -> discord.Colour: arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument try: value = int(arg, base=16) @@ -641,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]): else: return discord.Color(value=value) - def parse_rgb_number(self, argument, number): + def parse_rgb_number(self, argument: str, number: str) -> int: if number[-1] == '%': value = int(number[:-1]) if not (0 <= value <= 100): @@ -653,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]): raise BadColourArgument(argument) return value - def parse_rgb(self, argument, *, regex=RGB_REGEX): + def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour: match = regex.match(argument) if match is None: raise BadColourArgument(argument) @@ -663,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]): blue = self.parse_rgb_number(argument, match.group('b')) return discord.Color.from_rgb(red, green, blue) - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour: if argument[0] == '#': return self.parse_hex_number(argument[1:]) @@ -704,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]): Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role: guild = ctx.guild if not guild: raise NoPrivateMessage() @@ -723,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]): class GameConverter(Converter[discord.Game]): """Converts to :class:`~discord.Game`.""" - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game: return discord.Game(name=argument) @@ -736,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]): Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite: try: invite = await ctx.bot.fetch_invite(argument) return invite @@ -755,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]): .. versionadded:: 1.7 """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild: match = self._get_id_match(argument) result = None @@ -787,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji: match = self._get_id_match(argument) or re.match(r'$', argument) result = None bot = ctx.bot @@ -821,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji: match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) if match: @@ -850,7 +848,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]): .. versionadded:: 2.0 """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker: match = self._get_id_match(argument) result = None bot = ctx.bot @@ -890,7 +888,7 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]): .. versionadded:: 2.0 """ - async def convert(self, ctx: Context[_Bot], argument: str) -> discord.ScheduledEvent: + async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent: guild = ctx.guild match = self._get_id_match(argument) result = None @@ -967,7 +965,7 @@ class clean_content(Converter[str]): self.escape_markdown = escape_markdown self.remove_markdown = remove_markdown - async def convert(self, ctx: Context[_Bot], argument: str) -> str: + async def convert(self, ctx: Context[BotT], argument: str) -> str: msg = ctx.message if ctx.guild: @@ -1047,10 +1045,10 @@ class Greedy(List[T]): __slots__ = ('converter',) - def __init__(self, *, converter: T): - self.converter = converter + def __init__(self, *, converter: T) -> None: + self.converter: T = converter - def __repr__(self): + def __repr__(self) -> str: converter = getattr(self.converter, '__name__', repr(self.converter)) return f'Greedy[{converter}]' @@ -1099,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any: _GenericAlias = type(List[T]) -def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: - return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore +def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool: + return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) -CONVERTER_MAPPING: Dict[Type[Any], Any] = { +CONVERTER_MAPPING: Dict[type, Any] = { discord.Object: ObjectConverter, discord.Member: MemberConverter, discord.User: UserConverter, @@ -1128,7 +1126,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = { } -async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): +async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter): if converter is bool: return _convert_to_bool(argument) @@ -1166,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc -async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): +async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any: """|coro| Runs converters for a given converter, argument, and parameter. diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index e188712b3..875ef145f 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -220,7 +220,7 @@ class CooldownMapping: return self._type @classmethod - def from_cooldown(cls, rate, per, type) -> Self: + def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self: return cls(Cooldown(rate, per), type) def _bucket_key(self, msg: Message) -> Any: diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 57eeba70b..44c32db0e 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -61,6 +61,8 @@ if TYPE_CHECKING: from discord.message import Message from ._types import ( + BotT, + ContextT, Coro, CoroFunc, Check, @@ -101,7 +103,6 @@ MISSING: Any = discord.utils.MISSING T = TypeVar('T') CogT = TypeVar('CogT', bound='Optional[Cog]') CommandT = TypeVar('CommandT', bound='Command') -ContextT = TypeVar('ContextT', bound='Context') # CHT = TypeVar('CHT', bound='Check') GroupT = TypeVar('GroupT', bound='Group') FuncT = TypeVar('FuncT', bound=Callable[..., Any]) @@ -159,9 +160,9 @@ def get_signature_parameters( return params -def wrap_callback(coro): +def wrap_callback(coro: Callable[P, Coro[T]]) -> Callable[P, Coro[Optional[T]]]: @functools.wraps(coro) - async def wrapped(*args, **kwargs): + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: try: ret = await coro(*args, **kwargs) except CommandError: @@ -175,9 +176,11 @@ def wrap_callback(coro): return wrapped -def hooked_wrapped_callback(command, ctx, coro): +def hooked_wrapped_callback( + command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]] +) -> Callable[P, Coro[Optional[T]]]: @functools.wraps(coro) - async def wrapped(*args, **kwargs): + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: try: ret = await coro(*args, **kwargs) except CommandError: @@ -191,7 +194,7 @@ def hooked_wrapped_callback(command, ctx, coro): raise CommandInvokeError(exc) from exc finally: if command._max_concurrency is not None: - await command._max_concurrency.release(ctx) + await command._max_concurrency.release(ctx.message) await command.call_after_hooks(ctx) return ret @@ -359,7 +362,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: checks = kwargs.get('checks', []) - self.checks: List[Check] = checks + self.checks: List[Check[ContextT]] = checks try: cooldown = func.__commands_cooldown__ @@ -387,8 +390,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.cog: CogT = None # bandaid for the fact that sometimes parent can be the bot instance - parent = kwargs.get('parent') - self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore + parent: Optional[GroupMixin[Any]] = kwargs.get('parent') + self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None self._before_invoke: Optional[Hook] = None try: @@ -422,16 +425,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ) -> None: self._callback = function unwrap = unwrap_function(function) - self.module = unwrap.__module__ + self.module: str = unwrap.__module__ try: globalns = unwrap.__globals__ except AttributeError: globalns = {} - self.params = get_signature_parameters(function, globalns) + self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns) - def add_check(self, func: Check, /) -> None: + def add_check(self, func: Check[ContextT], /) -> None: """Adds a check to the command. This is the non-decorator interface to :func:`.check`. @@ -450,7 +453,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.checks.append(func) - def remove_check(self, func: Check, /) -> None: + def remove_check(self, func: Check[ContextT], /) -> None: """Removes a check from the command. This function is idempotent and will not raise an exception @@ -484,7 +487,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) self.cog = cog - async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T: + async def __call__(self, context: Context[BotT], *args: P.args, **kwargs: P.kwargs) -> T: """|coro| Calls the internal callback that the command holds. @@ -539,7 +542,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): else: return self.copy() - async def dispatch_error(self, ctx: Context, error: Exception) -> None: + async def dispatch_error(self, ctx: Context[BotT], error: CommandError) -> None: ctx.command_failed = True cog = self.cog try: @@ -549,7 +552,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): else: injected = wrap_callback(coro) if cog is not None: - await injected(cog, ctx, error) + await injected(cog, ctx, error) # type: ignore else: await injected(ctx, error) @@ -562,7 +565,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.bot.dispatch('command_error', ctx, error) - async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: + async def transform(self, ctx: Context[BotT], param: inspect.Parameter) -> Any: required = param.default is param.empty converter = get_converter(param) consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw @@ -610,7 +613,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # type-checker fails to narrow argument return await run_converters(ctx, converter, argument, param) # type: ignore - async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any: + async def _transform_greedy_pos( + self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any + ) -> Any: view = ctx.view result = [] while not view.eof: @@ -631,7 +636,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return param.default return result - async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: + async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any: view = ctx.view previous = view.index try: @@ -669,7 +674,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return ' '.join(reversed(entries)) @property - def parents(self) -> List[Group]: + def parents(self) -> List[Group[Any, ..., Any]]: """List[:class:`Group`]: Retrieves the parents of this command. If the command has no parents then it returns an empty :class:`list`. @@ -687,7 +692,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return entries @property - def root_parent(self) -> Optional[Group]: + def root_parent(self) -> Optional[Group[Any, ..., Any]]: """Optional[:class:`Group`]: Retrieves the root parent of this command. If the command has no parents then it returns ``None``. @@ -716,7 +721,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): def __str__(self) -> str: return self.qualified_name - async def _parse_arguments(self, ctx: Context) -> None: + async def _parse_arguments(self, ctx: Context[BotT]) -> None: ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = {} args = ctx.args @@ -752,7 +757,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if not self.ignore_extra and not view.eof: raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) - async def call_before_hooks(self, ctx: Context) -> None: + async def call_before_hooks(self, ctx: Context[BotT]) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog @@ -777,7 +782,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if hook is not None: await hook(ctx) - async def call_after_hooks(self, ctx: Context) -> None: + async def call_after_hooks(self, ctx: Context[BotT]) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, '__self__', cog) @@ -796,7 +801,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if hook is not None: await hook(ctx) - def _prepare_cooldowns(self, ctx: Context) -> None: + def _prepare_cooldowns(self, ctx: Context[BotT]) -> None: if self._buckets.valid: dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() @@ -806,7 +811,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - async def prepare(self, ctx: Context) -> None: + async def prepare(self, ctx: Context[BotT]) -> None: ctx.command = self if not await self.can_run(ctx): @@ -830,7 +835,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): await self._max_concurrency.release(ctx) # type: ignore raise - def is_on_cooldown(self, ctx: Context) -> bool: + def is_on_cooldown(self, ctx: Context[BotT]) -> bool: """Checks whether the command is currently on cooldown. Parameters @@ -851,7 +856,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() return bucket.get_tokens(current) == 0 - def reset_cooldown(self, ctx: Context) -> None: + def reset_cooldown(self, ctx: Context[BotT]) -> None: """Resets the cooldown on this command. Parameters @@ -863,7 +868,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): bucket = self._buckets.get_bucket(ctx.message) bucket.reset() - def get_cooldown_retry_after(self, ctx: Context) -> float: + def get_cooldown_retry_after(self, ctx: Context[BotT]) -> float: """Retrieves the amount of seconds before this command can be tried again. .. versionadded:: 1.4 @@ -887,7 +892,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return 0.0 - async def invoke(self, ctx: Context) -> None: + async def invoke(self, ctx: Context[BotT]) -> None: await self.prepare(ctx) # terminate the invoked_subcommand chain. @@ -896,9 +901,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.invoked_subcommand = None ctx.subcommand_passed = None injected = hooked_wrapped_callback(self, ctx, self.callback) - await injected(*ctx.args, **ctx.kwargs) + await injected(*ctx.args, **ctx.kwargs) # type: ignore - async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: + async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None: ctx.command = self await self._parse_arguments(ctx) @@ -936,7 +941,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if not asyncio.iscoroutinefunction(coro): raise TypeError('The error handler must be a coroutine.') - self.on_error: Error = coro + self.on_error: Error[Any] = coro return coro def has_error_handler(self) -> bool: @@ -1075,7 +1080,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return ' '.join(result) - async def can_run(self, ctx: Context) -> bool: + async def can_run(self, ctx: Context[BotT]) -> bool: """|coro| Checks if the command can be executed by checking all the predicates @@ -1341,7 +1346,7 @@ class GroupMixin(Generic[CogT]): def command( self, name: str = MISSING, - cls: Type[Command] = MISSING, + cls: Type[Command[Any, ..., Any]] = MISSING, *args: Any, **kwargs: Any, ) -> Any: @@ -1401,7 +1406,7 @@ class GroupMixin(Generic[CogT]): def group( self, name: str = MISSING, - cls: Type[Group] = MISSING, + cls: Type[Group[Any, ..., Any]] = MISSING, *args: Any, **kwargs: Any, ) -> Any: @@ -1461,9 +1466,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): ret = super().copy() for cmd in self.commands: ret.add_command(cmd.copy()) - return ret # type: ignore + return ret - async def invoke(self, ctx: Context) -> None: + async def invoke(self, ctx: Context[BotT]) -> None: ctx.invoked_subcommand = None ctx.subcommand_passed = None early_invoke = not self.invoke_without_command @@ -1481,7 +1486,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): if early_invoke: injected = hooked_wrapped_callback(self, ctx, self.callback) - await injected(*ctx.args, **ctx.kwargs) + await injected(*ctx.args, **ctx.kwargs) # type: ignore ctx.invoked_parents.append(ctx.invoked_with) # type: ignore @@ -1494,7 +1499,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().invoke(ctx) - async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: + async def reinvoke(self, ctx: Context[BotT], *, call_hooks: bool = False) -> None: ctx.invoked_subcommand = None early_invoke = not self.invoke_without_command if early_invoke: @@ -1592,7 +1597,7 @@ def command( def command( name: str = MISSING, - cls: Type[Command] = MISSING, + cls: Type[Command[Any, ..., Any]] = MISSING, **attrs: Any, ) -> Any: """A decorator that transforms a function into a :class:`.Command` @@ -1662,7 +1667,7 @@ def group( def group( name: str = MISSING, - cls: Type[Group] = MISSING, + cls: Type[Group[Any, ..., Any]] = MISSING, **attrs: Any, ) -> Any: """A decorator that transforms a function into a :class:`.Group`. @@ -1679,7 +1684,7 @@ def group( return command(name=name, cls=cls, **attrs) -def check(predicate: Check) -> Callable[[T], T]: +def check(predicate: Check[ContextT]) -> Callable[[T], T]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1774,7 +1779,7 @@ def check(predicate: Check) -> Callable[[T], T]: return decorator # type: ignore -def check_any(*checks: Check) -> Callable[[T], T]: +def check_any(*checks: Check[ContextT]) -> Callable[[T], T]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1827,7 +1832,7 @@ def check_any(*checks: Check) -> Callable[[T], T]: else: unwrapped.append(pred) - async def predicate(ctx: Context) -> bool: + async def predicate(ctx: Context[BotT]) -> bool: errors = [] for func in unwrapped: try: @@ -1870,7 +1875,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]: The name or ID of the role to check. """ - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: if ctx.guild is None: raise NoPrivateMessage() @@ -1923,7 +1928,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: raise NoPrivateMessage() # ctx.guild is None doesn't narrow ctx.author to Member - getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore + getter = functools.partial(discord.utils.get, ctx.author.roles) if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): return True raise MissingAnyRole(list(items)) @@ -2022,7 +2027,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: ch = ctx.channel permissions = ch.permissions_for(ctx.author) # type: ignore @@ -2048,7 +2053,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: guild = ctx.guild me = guild.me if guild is not None else ctx.bot.user permissions = ctx.channel.permissions_for(me) # type: ignore @@ -2077,7 +2082,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: if not ctx.guild: raise NoPrivateMessage @@ -2103,7 +2108,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: if not ctx.guild: raise NoPrivateMessage @@ -2129,7 +2134,7 @@ def dm_only() -> Callable[[T], T]: .. versionadded:: 1.1 """ - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: if ctx.guild is not None: raise PrivateMessageOnly() return True @@ -2146,7 +2151,7 @@ def guild_only() -> Callable[[T], T]: that is inherited from :exc:`.CheckFailure`. """ - def predicate(ctx: Context) -> bool: + def predicate(ctx: Context[BotT]) -> bool: if ctx.guild is None: raise NoPrivateMessage() return True @@ -2164,7 +2169,7 @@ def is_owner() -> Callable[[T], T]: from :exc:`.CheckFailure`. """ - async def predicate(ctx: Context) -> bool: + async def predicate(ctx: Context[BotT]) -> bool: if not await ctx.bot.is_owner(ctx.author): raise NotOwner('You do not own this bot.') return True @@ -2184,7 +2189,7 @@ def is_nsfw() -> Callable[[T], T]: DM channels will also now pass this check. """ - def pred(ctx: Context) -> bool: + def pred(ctx: Context[BotT]) -> bool: ch = ctx.channel if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 2047759e0..2b0567b5b 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -39,6 +39,8 @@ if TYPE_CHECKING: from discord.threads import Thread from discord.types.snowflake import Snowflake, SnowflakeList + from ._types import BotT + __all__ = ( 'CommandError', @@ -135,8 +137,8 @@ class ConversionError(CommandError): the ``__cause__`` attribute. """ - def __init__(self, converter: Converter, original: Exception) -> None: - self.converter: Converter = converter + def __init__(self, converter: Converter[Any], original: Exception) -> None: + self.converter: Converter[Any] = converter self.original: Exception = original @@ -224,9 +226,9 @@ class CheckAnyFailure(CheckFailure): A list of check predicates that failed. """ - def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: + def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None: self.checks: List[CheckFailure] = checks - self.errors: List[Callable[[Context], bool]] = errors + self.errors: List[Callable[[Context[BotT]], bool]] = errors super().__init__('You do not have permission to run this command.') @@ -807,9 +809,9 @@ class BadUnionArgument(UserInputError): A list of errors that were caught from failing the conversion. """ - def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: + def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None: self.param: Parameter = param - self.converters: Tuple[Type, ...] = converters + self.converters: Tuple[type, ...] = converters self.errors: List[CommandError] = errors def _get_name(x): diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index 041736913..8de2f237c 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -49,8 +49,6 @@ from typing import ( Tuple, List, Any, - Type, - TypeVar, Union, ) @@ -70,6 +68,8 @@ if TYPE_CHECKING: from .context import Context + from ._types import BotT + @dataclass class Flag: @@ -148,7 +148,7 @@ def flag( return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) -def validate_flag_name(name: str, forbidden: Set[str]): +def validate_flag_name(name: str, forbidden: Set[str]) -> None: if not name: raise ValueError('flag names should not be empty') @@ -348,7 +348,7 @@ class FlagsMeta(type): return type.__new__(cls, name, bases, attrs) -async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: +async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: view = StringView(argument) results = [] param: inspect.Parameter = ctx.current_parameter # type: ignore @@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: return tuple(results) -async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: +async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: view = StringView(argument) results = [] param: inspect.Parameter = ctx.current_parameter # type: ignore @@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters return tuple(results) -async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any: +async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any: param: inspect.Parameter = ctx.current_parameter # type: ignore annotation = annotation or flag.annotation try: @@ -480,7 +480,7 @@ class FlagConverter(metaclass=FlagsMeta): yield (flag.name, getattr(self, flag.attribute)) @classmethod - async def _construct_default(cls, ctx: Context) -> Self: + async def _construct_default(cls, ctx: Context[BotT]) -> Self: self = cls.__new__(cls) flags = cls.__commands_flags__ for flag in flags.values(): @@ -546,7 +546,7 @@ class FlagConverter(metaclass=FlagsMeta): return result @classmethod - async def convert(cls, ctx: Context, argument: str) -> Self: + async def convert(cls, ctx: Context[BotT], argument: str) -> Self: """|coro| The method that actually converters an argument to the flag mapping. @@ -610,7 +610,7 @@ class FlagConverter(metaclass=FlagsMeta): values = [await convert_flag(ctx, value, flag) for value in values] if flag.cast_to_dict: - values = dict(values) # type: ignore + values = dict(values) setattr(self, flag.attribute, values) diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index da09b1707..259cf2f9f 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -22,13 +22,27 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import itertools import copy import functools -import inspect import re -from typing import Optional, TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Optional, + Generator, + List, + TypeVar, + Callable, + Any, + Dict, + Tuple, + Iterable, + Sequence, + Mapping, +) import discord.utils @@ -36,7 +50,21 @@ from .core import Group, Command, get_signature_parameters from .errors import CommandError if TYPE_CHECKING: + from typing_extensions import Self + import inspect + + import discord.abc + + from .bot import BotBase from .context import Context + from .cog import Cog + + from ._types import ( + Check, + ContextT, + BotT, + _Bot, + ) __all__ = ( 'Paginator', @@ -45,7 +73,9 @@ __all__ = ( 'MinimalHelpCommand', ) -MISSING = discord.utils.MISSING +FuncT = TypeVar('FuncT', bound=Callable[..., Any]) + +MISSING: Any = discord.utils.MISSING # help -> shows info of bot on top/bottom and lists subcommands # help command -> shows detailed info of command @@ -80,10 +110,10 @@ class Paginator: Attributes ----------- - prefix: :class:`str` - The prefix inserted to every page. e.g. three backticks. - suffix: :class:`str` - The suffix appended at the end of every page. e.g. three backticks. + prefix: Optional[:class:`str`] + The prefix inserted to every page. e.g. three backticks, if any. + suffix: Optional[:class:`str`] + The suffix appended at the end of every page. e.g. three backticks, if any. max_size: :class:`int` The maximum amount of codepoints allowed in a page. linesep: :class:`str` @@ -91,36 +121,38 @@ class Paginator: .. versionadded:: 1.7 """ - def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'): - self.prefix = prefix - self.suffix = suffix - self.max_size = max_size - self.linesep = linesep + def __init__( + self, prefix: Optional[str] = '```', suffix: Optional[str] = '```', max_size: int = 2000, linesep: str = '\n' + ) -> None: + self.prefix: Optional[str] = prefix + self.suffix: Optional[str] = suffix + self.max_size: int = max_size + self.linesep: str = linesep self.clear() - def clear(self): + def clear(self) -> None: """Clears the paginator to have no pages.""" if self.prefix is not None: - self._current_page = [self.prefix] - self._count = len(self.prefix) + self._linesep_len # prefix + newline + self._current_page: List[str] = [self.prefix] + self._count: int = len(self.prefix) + self._linesep_len # prefix + newline else: self._current_page = [] self._count = 0 - self._pages = [] + self._pages: List[str] = [] @property - def _prefix_len(self): + def _prefix_len(self) -> int: return len(self.prefix) if self.prefix else 0 @property - def _suffix_len(self): + def _suffix_len(self) -> int: return len(self.suffix) if self.suffix else 0 @property - def _linesep_len(self): + def _linesep_len(self) -> int: return len(self.linesep) - def add_line(self, line='', *, empty=False): + def add_line(self, line: str = '', *, empty: bool = False) -> None: """Adds a line to the current page. If the line exceeds the :attr:`max_size` then an exception @@ -152,7 +184,7 @@ class Paginator: self._current_page.append('') self._count += self._linesep_len - def close_page(self): + def close_page(self) -> None: """Prematurely terminate a page.""" if self.suffix is not None: self._current_page.append(self.suffix) @@ -165,36 +197,38 @@ class Paginator: self._current_page = [] self._count = 0 - def __len__(self): + def __len__(self) -> int: total = sum(len(p) for p in self._pages) return total + self._count @property - def pages(self): + def pages(self) -> List[str]: """List[:class:`str`]: Returns the rendered list of pages.""" # we have more than just the prefix in our current page if len(self._current_page) > (0 if self.prefix is None else 1): self.close_page() return self._pages - def __repr__(self): + def __repr__(self) -> str: fmt = '' return fmt.format(self) -def _not_overridden(f): +def _not_overridden(f: FuncT) -> FuncT: f.__help_command_not_overridden__ = True return f class _HelpCommandImpl(Command): - def __init__(self, inject, *args, **kwargs): + def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None: super().__init__(inject.command_callback, *args, **kwargs) - self._original = inject - self._injected = inject - self.params = get_signature_parameters(inject.command_callback, globals(), skip_parameters=1) + self._original: HelpCommand = inject + self._injected: HelpCommand = inject + self.params: Dict[str, inspect.Parameter] = get_signature_parameters( + inject.command_callback, globals(), skip_parameters=1 + ) - async def prepare(self, ctx): + async def prepare(self, ctx: Context[Any]) -> None: self._injected = injected = self._original.copy() injected.context = ctx self.callback = injected.command_callback @@ -209,7 +243,7 @@ class _HelpCommandImpl(Command): await super().prepare(ctx) - async def _parse_arguments(self, ctx): + async def _parse_arguments(self, ctx: Context[BotT]) -> None: # Make the parser think we don't have a cog so it doesn't # inject the parameter into `ctx.args`. original_cog = self.cog @@ -219,22 +253,26 @@ class _HelpCommandImpl(Command): finally: self.cog = original_cog - async def _on_error_cog_implementation(self, dummy, ctx, error): + async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None: await self._injected.on_help_command_error(ctx, error) - def _inject_into_cog(self, cog): + def _inject_into_cog(self, cog: Cog) -> None: # Warning: hacky # Make the cog think that get_commands returns this command # as well if we inject it without modifying __cog_commands__ # since that's used for the injection and ejection of cogs. - def wrapped_get_commands(*, _original=cog.get_commands): + def wrapped_get_commands( + *, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands + ) -> List[Command[Any, ..., Any]]: ret = _original() ret.append(self) return ret # Ditto here - def wrapped_walk_commands(*, _original=cog.walk_commands): + def wrapped_walk_commands( + *, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands + ): yield from _original() yield self @@ -244,7 +282,7 @@ class _HelpCommandImpl(Command): cog.walk_commands = wrapped_walk_commands self.cog = cog - def _eject_cog(self): + def _eject_cog(self) -> None: if self.cog is None: return @@ -298,7 +336,11 @@ class HelpCommand: MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) - def __new__(cls, *args, **kwargs): + if TYPE_CHECKING: + __original_kwargs__: Dict[str, Any] + __original_args__: Tuple[Any, ...] + + def __new__(cls, *args: Any, **kwargs: Any) -> Self: # To prevent race conditions of a single instance while also allowing # for settings to be passed the original arguments passed must be assigned # to allow for easier copies (which will be made when the help command is actually called) @@ -314,30 +356,31 @@ class HelpCommand: self.__original_args__ = deepcopy(args) return self - def __init__(self, **options): - self.show_hidden = options.pop('show_hidden', False) - self.verify_checks = options.pop('verify_checks', True) + def __init__(self, **options: Any) -> None: + self.show_hidden: bool = options.pop('show_hidden', False) + self.verify_checks: bool = options.pop('verify_checks', True) + self.command_attrs: Dict[str, Any] self.command_attrs = attrs = options.pop('command_attrs', {}) attrs.setdefault('name', 'help') attrs.setdefault('help', 'Shows this message') - self.context: Context = MISSING + self.context: Context[_Bot] = MISSING self._command_impl = _HelpCommandImpl(self, **self.command_attrs) - def copy(self): + def copy(self) -> Self: obj = self.__class__(*self.__original_args__, **self.__original_kwargs__) obj._command_impl = self._command_impl return obj - def _add_to_bot(self, bot): + def _add_to_bot(self, bot: BotBase) -> None: command = _HelpCommandImpl(self, **self.command_attrs) bot.add_command(command) self._command_impl = command - def _remove_from_bot(self, bot): + def _remove_from_bot(self, bot: BotBase) -> None: bot.remove_command(self._command_impl.name) self._command_impl._eject_cog() - def add_check(self, func, /): + def add_check(self, func: Check[ContextT], /) -> None: """ Adds a check to the help command. @@ -355,7 +398,7 @@ class HelpCommand: self._command_impl.add_check(func) - def remove_check(self, func, /): + def remove_check(self, func: Check[ContextT], /) -> None: """ Removes a check from the help command. @@ -376,15 +419,15 @@ class HelpCommand: self._command_impl.remove_check(func) - def get_bot_mapping(self): + def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]: """Retrieves the bot mapping passed to :meth:`send_bot_help`.""" bot = self.context.bot - mapping = {cog: cog.get_commands() for cog in bot.cogs.values()} + mapping: Dict[Optional[Cog], List[Command[Any, ..., Any]]] = {cog: cog.get_commands() for cog in bot.cogs.values()} mapping[None] = [c for c in bot.commands if c.cog is None] return mapping @property - def invoked_with(self): + def invoked_with(self) -> Optional[str]: """Similar to :attr:`Context.invoked_with` except properly handles the case where :meth:`Context.send_help` is used. @@ -395,7 +438,7 @@ class HelpCommand: Returns --------- - :class:`str` + Optional[:class:`str`] The command name that triggered this invocation. """ command_name = self._command_impl.name @@ -404,7 +447,7 @@ class HelpCommand: return command_name return ctx.invoked_with - def get_command_signature(self, command): + def get_command_signature(self, command: Command[Any, ..., Any]) -> str: """Retrieves the signature portion of the help page. Parameters @@ -418,14 +461,14 @@ class HelpCommand: The signature for the command. """ - parent = command.parent + parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore - the parent will be a Group entries = [] while parent is not None: if not parent.signature or parent.invoke_without_command: entries.append(parent.name) else: entries.append(parent.name + ' ' + parent.signature) - parent = parent.parent + parent = parent.parent # type: ignore parent_sig = ' '.join(reversed(entries)) if len(command.aliases) > 0: @@ -439,7 +482,7 @@ class HelpCommand: return f'{self.context.clean_prefix}{alias} {command.signature}' - def remove_mentions(self, string): + def remove_mentions(self, string: str) -> str: """Removes mentions from the string to prevent abuse. This includes ``@everyone``, ``@here``, member mentions and role mentions. @@ -450,13 +493,13 @@ class HelpCommand: The string with mentions removed. """ - def replace(obj, *, transforms=self.MENTION_TRANSFORMS): + def replace(obj: re.Match, *, transforms: Dict[str, str] = self.MENTION_TRANSFORMS) -> str: return transforms.get(obj.group(0), '@invalid') return self.MENTION_PATTERN.sub(replace, string) @property - def cog(self): + def cog(self) -> Optional[Cog]: """A property for retrieving or setting the cog for the help command. When a cog is set for the help command, it is as-if the help command @@ -473,7 +516,7 @@ class HelpCommand: return self._command_impl.cog @cog.setter - def cog(self, cog): + def cog(self, cog: Optional[Cog]) -> None: # Remove whatever cog is currently valid, if any self._command_impl._eject_cog() @@ -481,7 +524,7 @@ class HelpCommand: if cog is not None: self._command_impl._inject_into_cog(cog) - def command_not_found(self, string): + def command_not_found(self, string: str) -> str: """|maybecoro| A method called when a command is not found in the help command. @@ -502,7 +545,7 @@ class HelpCommand: """ return f'No command called "{string}" found.' - def subcommand_not_found(self, command, string): + def subcommand_not_found(self, command: Command[Any, ..., Any], string: str) -> str: """|maybecoro| A method called when a command did not have a subcommand requested in the help command. @@ -532,7 +575,13 @@ class HelpCommand: return f'Command "{command.qualified_name}" has no subcommand named {string}' return f'Command "{command.qualified_name}" has no subcommands.' - async def filter_commands(self, commands, *, sort=False, key=None): + async def filter_commands( + self, + commands: Iterable[Command[Any, ..., Any]], + *, + sort: bool = False, + key: Optional[Callable[[Command[Any, ..., Any]], Any]] = None, + ) -> List[Command[Any, ..., Any]]: """|coro| Returns a filtered list of commands and optionally sorts them. @@ -546,7 +595,7 @@ class HelpCommand: An iterable of commands that are getting filtered. sort: :class:`bool` Whether to sort the result. - key: Optional[Callable[:class:`Command`, Any]] + key: Optional[Callable[[:class:`Command`], Any]] An optional key function to pass to :func:`py:sorted` that takes a :class:`Command` as its sole parameter. If ``sort`` is passed as ``True`` then this will default as the command name. @@ -565,14 +614,14 @@ class HelpCommand: if self.verify_checks is False: # if we do not need to verify the checks then we can just # run it straight through normally without using await. - return sorted(iterator, key=key) if sort else list(iterator) + return sorted(iterator, key=key) if sort else list(iterator) # type: ignore - the key shouldn't be None if self.verify_checks is None and not self.context.guild: # if verify_checks is None and we're in a DM, don't verify - return sorted(iterator, key=key) if sort else list(iterator) + return sorted(iterator, key=key) if sort else list(iterator) # type: ignore # if we're here then we need to check every command if it can run - async def predicate(cmd): + async def predicate(cmd: Command[Any, ..., Any]) -> bool: try: return await cmd.can_run(self.context) except CommandError: @@ -588,7 +637,7 @@ class HelpCommand: ret.sort(key=key) return ret - def get_max_size(self, commands): + def get_max_size(self, commands: Sequence[Command[Any, ..., Any]]) -> int: """Returns the largest name length of the specified command list. Parameters @@ -605,7 +654,7 @@ class HelpCommand: as_lengths = (discord.utils._string_width(c.name) for c in commands) return max(as_lengths, default=0) - def get_destination(self): + def get_destination(self) -> discord.abc.MessageableChannel: """Returns the :class:`~discord.abc.Messageable` where the help command will be output. You can override this method to customise the behaviour. @@ -619,7 +668,7 @@ class HelpCommand: """ return self.context.channel - async def send_error_message(self, error): + async def send_error_message(self, error: str) -> None: """|coro| Handles the implementation when an error happens in the help command. @@ -644,7 +693,7 @@ class HelpCommand: await destination.send(error) @_not_overridden - async def on_help_command_error(self, ctx, error): + async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None: """|coro| The help command's error handler, as specified by :ref:`ext_commands_error_handler`. @@ -664,7 +713,7 @@ class HelpCommand: """ pass - async def send_bot_help(self, mapping): + async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None: """|coro| Handles the implementation of the bot command page in the help command. @@ -693,7 +742,7 @@ class HelpCommand: """ return None - async def send_cog_help(self, cog): + async def send_cog_help(self, cog: Cog) -> None: """|coro| Handles the implementation of the cog page in the help command. @@ -721,7 +770,7 @@ class HelpCommand: """ return None - async def send_group_help(self, group): + async def send_group_help(self, group: Group[Any, ..., Any]) -> None: """|coro| Handles the implementation of the group page in the help command. @@ -749,7 +798,7 @@ class HelpCommand: """ return None - async def send_command_help(self, command): + async def send_command_help(self, command: Command[Any, ..., Any]) -> None: """|coro| Handles the implementation of the single command page in the help command. @@ -787,7 +836,7 @@ class HelpCommand: """ return None - async def prepare_help_command(self, ctx, command=None): + async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None: """|coro| A low level method that can be used to prepare the help command @@ -811,7 +860,7 @@ class HelpCommand: """ pass - async def command_callback(self, ctx, *, command=None): + async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None: """|coro| The actual implementation of the help command. @@ -856,7 +905,7 @@ class HelpCommand: for key in keys[1:]: try: - found = cmd.all_commands.get(key) + found = cmd.all_commands.get(key) # type: ignore except AttributeError: string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) @@ -908,28 +957,28 @@ class DefaultHelpCommand(HelpCommand): The paginator used to paginate the help command output. """ - def __init__(self, **options): - self.width = options.pop('width', 80) - self.indent = options.pop('indent', 2) - self.sort_commands = options.pop('sort_commands', True) - self.dm_help = options.pop('dm_help', False) - self.dm_help_threshold = options.pop('dm_help_threshold', 1000) - self.commands_heading = options.pop('commands_heading', "Commands:") - self.no_category = options.pop('no_category', 'No Category') - self.paginator = options.pop('paginator', None) + def __init__(self, **options: Any) -> None: + self.width: int = options.pop('width', 80) + self.indent: int = options.pop('indent', 2) + self.sort_commands: bool = options.pop('sort_commands', True) + self.dm_help: bool = options.pop('dm_help', False) + self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000) + self.commands_heading: str = options.pop('commands_heading', "Commands:") + self.no_category: str = options.pop('no_category', 'No Category') + self.paginator: Paginator = options.pop('paginator', None) if self.paginator is None: - self.paginator = Paginator() + self.paginator: Paginator = Paginator() super().__init__(**options) - def shorten_text(self, text): + def shorten_text(self, text: str) -> str: """:class:`str`: Shortens text to fit into the :attr:`width`.""" if len(text) > self.width: return text[: self.width - 3].rstrip() + '...' return text - def get_ending_note(self): + def get_ending_note(self) -> str: """:class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes.""" command_name = self.invoked_with return ( @@ -937,7 +986,9 @@ class DefaultHelpCommand(HelpCommand): f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category." ) - def add_indented_commands(self, commands, *, heading, max_size=None): + def add_indented_commands( + self, commands: Sequence[Command[Any, ..., Any]], *, heading: str, max_size: Optional[int] = None + ) -> None: """Indents a list of commands after the specified heading. The formatting is added to the :attr:`paginator`. @@ -973,13 +1024,13 @@ class DefaultHelpCommand(HelpCommand): entry = f'{self.indent * " "}{name:<{width}} {command.short_doc}' self.paginator.add_line(self.shorten_text(entry)) - async def send_pages(self): + async def send_pages(self) -> None: """A helper utility to send the page output from :attr:`paginator` to the destination.""" destination = self.get_destination() for page in self.paginator.pages: await destination.send(page) - def add_command_formatting(self, command): + def add_command_formatting(self, command: Command[Any, ..., Any]) -> None: """A utility function to format the non-indented block of commands and groups. Parameters @@ -1002,7 +1053,7 @@ class DefaultHelpCommand(HelpCommand): self.paginator.add_line(line) self.paginator.add_line() - def get_destination(self): + def get_destination(self) -> discord.abc.Messageable: ctx = self.context if self.dm_help is True: return ctx.author @@ -1011,11 +1062,11 @@ class DefaultHelpCommand(HelpCommand): else: return ctx.channel - async def prepare_help_command(self, ctx, command): + async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: self.paginator.clear() await super().prepare_help_command(ctx, command) - async def send_bot_help(self, mapping): + async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None: ctx = self.context bot = ctx.bot @@ -1045,12 +1096,12 @@ class DefaultHelpCommand(HelpCommand): await self.send_pages() - async def send_command_help(self, command): + async def send_command_help(self, command: Command[Any, ..., Any]) -> None: self.add_command_formatting(command) self.paginator.close_page() await self.send_pages() - async def send_group_help(self, group): + async def send_group_help(self, group: Group[Any, ..., Any]) -> None: self.add_command_formatting(group) filtered = await self.filter_commands(group.commands, sort=self.sort_commands) @@ -1064,7 +1115,7 @@ class DefaultHelpCommand(HelpCommand): await self.send_pages() - async def send_cog_help(self, cog): + async def send_cog_help(self, cog: Cog) -> None: if cog.description: self.paginator.add_line(cog.description, empty=True) @@ -1111,27 +1162,27 @@ class MinimalHelpCommand(HelpCommand): The paginator used to paginate the help command output. """ - def __init__(self, **options): - self.sort_commands = options.pop('sort_commands', True) - self.commands_heading = options.pop('commands_heading', "Commands") - self.dm_help = options.pop('dm_help', False) - self.dm_help_threshold = options.pop('dm_help_threshold', 1000) - self.aliases_heading = options.pop('aliases_heading', "Aliases:") - self.no_category = options.pop('no_category', 'No Category') - self.paginator = options.pop('paginator', None) + def __init__(self, **options: Any) -> None: + self.sort_commands: bool = options.pop('sort_commands', True) + self.commands_heading: str = options.pop('commands_heading', "Commands") + self.dm_help: bool = options.pop('dm_help', False) + self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000) + self.aliases_heading: str = options.pop('aliases_heading', "Aliases:") + self.no_category: str = options.pop('no_category', 'No Category') + self.paginator: Paginator = options.pop('paginator', None) if self.paginator is None: - self.paginator = Paginator(suffix=None, prefix=None) + self.paginator: Paginator = Paginator(suffix=None, prefix=None) super().__init__(**options) - async def send_pages(self): + async def send_pages(self) -> None: """A helper utility to send the page output from :attr:`paginator` to the destination.""" destination = self.get_destination() for page in self.paginator.pages: await destination.send(page) - def get_opening_note(self): + def get_opening_note(self) -> str: """Returns help command's opening note. This is mainly useful to override for i18n purposes. The default implementation returns :: @@ -1150,10 +1201,10 @@ class MinimalHelpCommand(HelpCommand): f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category." ) - def get_command_signature(self, command): + def get_command_signature(self, command: Command[Any, ..., Any]) -> str: return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}' - def get_ending_note(self): + def get_ending_note(self) -> str: """Return the help command's ending note. This is mainly useful to override for i18n purposes. The default implementation does nothing. @@ -1163,9 +1214,9 @@ class MinimalHelpCommand(HelpCommand): :class:`str` The help command ending note. """ - return None + return '' - def add_bot_commands_formatting(self, commands, heading): + def add_bot_commands_formatting(self, commands: Sequence[Command[Any, ..., Any]], heading: str) -> None: """Adds the minified bot heading with commands to the output. The formatting should be added to the :attr:`paginator`. @@ -1186,7 +1237,7 @@ class MinimalHelpCommand(HelpCommand): self.paginator.add_line(f'__**{heading}**__') self.paginator.add_line(joined) - def add_subcommand_formatting(self, command): + def add_subcommand_formatting(self, command: Command[Any, ..., Any]) -> None: """Adds formatting information on a subcommand. The formatting should be added to the :attr:`paginator`. @@ -1202,7 +1253,7 @@ class MinimalHelpCommand(HelpCommand): fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}' self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) - def add_aliases_formatting(self, aliases): + def add_aliases_formatting(self, aliases: Sequence[str]) -> None: """Adds the formatting information on a command's aliases. The formatting should be added to the :attr:`paginator`. @@ -1219,7 +1270,7 @@ class MinimalHelpCommand(HelpCommand): """ self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True) - def add_command_formatting(self, command): + def add_command_formatting(self, command: Command[Any, ..., Any]) -> None: """A utility function to format commands and groups. Parameters @@ -1246,7 +1297,7 @@ class MinimalHelpCommand(HelpCommand): self.paginator.add_line(line) self.paginator.add_line() - def get_destination(self): + def get_destination(self) -> discord.abc.Messageable: ctx = self.context if self.dm_help is True: return ctx.author @@ -1255,11 +1306,11 @@ class MinimalHelpCommand(HelpCommand): else: return ctx.channel - async def prepare_help_command(self, ctx, command): + async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: self.paginator.clear() await super().prepare_help_command(ctx, command) - async def send_bot_help(self, mapping): + async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, ..., Any]]]) -> None: ctx = self.context bot = ctx.bot @@ -1272,7 +1323,7 @@ class MinimalHelpCommand(HelpCommand): no_category = f'\u200b{self.no_category}' - def get_category(command, *, no_category=no_category): + def get_category(command: Command[Any, ..., Any], *, no_category: str = no_category) -> str: cog = command.cog return cog.qualified_name if cog is not None else no_category @@ -1290,7 +1341,7 @@ class MinimalHelpCommand(HelpCommand): await self.send_pages() - async def send_cog_help(self, cog): + async def send_cog_help(self, cog: Cog) -> None: bot = self.context.bot if bot.description: self.paginator.add_line(bot.description, empty=True) @@ -1315,7 +1366,7 @@ class MinimalHelpCommand(HelpCommand): await self.send_pages() - async def send_group_help(self, group): + async def send_group_help(self, group: Group[Any, ..., Any]) -> None: self.add_command_formatting(group) filtered = await self.filter_commands(group.commands, sort=self.sort_commands) @@ -1335,7 +1386,7 @@ class MinimalHelpCommand(HelpCommand): await self.send_pages() - async def send_command_help(self, command): + async def send_command_help(self, command: Command[Any, ..., Any]) -> None: self.add_command_formatting(command) self.paginator.close_page() await self.send_pages() diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index b86298822..96d086811 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + +from __future__ import annotations + +from typing import Optional + from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError # map from opening quotes to closing quotes @@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values()) class StringView: - def __init__(self, buffer): - self.index = 0 - self.buffer = buffer - self.end = len(buffer) + def __init__(self, buffer: str) -> None: + self.index: int = 0 + self.buffer: str = buffer + self.end: int = len(buffer) self.previous = 0 @property - def current(self): + def current(self) -> Optional[str]: return None if self.eof else self.buffer[self.index] @property - def eof(self): + def eof(self) -> bool: return self.index >= self.end - def undo(self): + def undo(self) -> None: self.index = self.previous - def skip_ws(self): + def skip_ws(self) -> bool: pos = 0 while not self.eof: try: @@ -79,7 +84,7 @@ class StringView: self.index += pos return self.previous != self.index - def skip_string(self, string): + def skip_string(self, string: str) -> bool: strlen = len(string) if self.buffer[self.index : self.index + strlen] == string: self.previous = self.index @@ -87,19 +92,19 @@ class StringView: return True return False - def read_rest(self): + def read_rest(self) -> str: result = self.buffer[self.index :] self.previous = self.index self.index = self.end return result - def read(self, n): + def read(self, n: int) -> str: result = self.buffer[self.index : self.index + n] self.previous = self.index self.index += n return result - def get(self): + def get(self) -> Optional[str]: try: result = self.buffer[self.index + 1] except IndexError: @@ -109,7 +114,7 @@ class StringView: self.index += 1 return result - def get_word(self): + def get_word(self) -> str: pos = 0 while not self.eof: try: @@ -119,12 +124,12 @@ class StringView: pos += 1 except IndexError: break - self.previous = self.index + self.previous: int = self.index result = self.buffer[self.index : self.index + pos] self.index += pos return result - def get_quoted_word(self): + def get_quoted_word(self) -> Optional[str]: current = self.current if current is None: return None @@ -187,5 +192,5 @@ class StringView: result.append(current) - def __repr__(self): + def __repr__(self) -> str: return f'' diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 32fd9009b..9bd101a45 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -110,15 +110,15 @@ class SleepHandle: __slots__ = ('future', 'loop', 'handle') def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: - self.loop = loop - self.future = future = loop.create_future() + self.loop: asyncio.AbstractEventLoop = loop + self.future: asyncio.Future[None] = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = loop.call_later(relative_delta, future.set_result, True) + self.handle = loop.call_later(relative_delta, self.future.set_result, True) def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) + self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True) def wait(self) -> asyncio.Future[Any]: return self.future diff --git a/discord/file.py b/discord/file.py index 4b060554b..e51744488 100644 --- a/discord/file.py +++ b/discord/file.py @@ -74,7 +74,7 @@ class File: def __init__( self, - fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], + fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], filename: Optional[str] = None, *, spoiler: bool = False, diff --git a/discord/flags.py b/discord/flags.py index b776ddb57..beddb1f12 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -46,8 +46,8 @@ BF = TypeVar('BF', bound='BaseFlags') class flag_value: def __init__(self, func: Callable[[Any], int]): - self.flag = func(None) - self.__doc__ = func.__doc__ + self.flag: int = func(None) + self.__doc__: Optional[str] = func.__doc__ @overload def __get__(self, instance: None, owner: Type[BF]) -> Self: @@ -65,7 +65,7 @@ class flag_value: def __set__(self, instance: BaseFlags, value: bool) -> None: instance._set_flag(self.flag, value) - def __repr__(self): + def __repr__(self) -> str: return f'' @@ -73,8 +73,8 @@ class alias_flag_value(flag_value): pass -def fill_with_flags(*, inverted: bool = False): - def decorator(cls: Type[BF]): +def fill_with_flags(*, inverted: bool = False) -> Callable[[Type[BF]], Type[BF]]: + def decorator(cls: Type[BF]) -> Type[BF]: # fmt: off cls.VALID_FLAGS = { name: value.flag @@ -116,10 +116,10 @@ class BaseFlags: self.value = value return self - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.value == other.value - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -504,8 +504,8 @@ class Intents(BaseFlags): __slots__ = () - def __init__(self, **kwargs: bool): - self.value = self.DEFAULT_VALUE + def __init__(self, **kwargs: bool) -> None: + self.value: int = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: raise TypeError(f'{key!r} is not a valid flag name.') @@ -1005,7 +1005,7 @@ class MemberCacheFlags(BaseFlags): def __init__(self, **kwargs: bool): bits = max(self.VALID_FLAGS.values()).bit_length() - self.value = (1 << bits) - 1 + self.value: int = (1 << bits) - 1 for key, value in kwargs.items(): if key not in self.VALID_FLAGS: raise TypeError(f'{key!r} is not a valid flag name.') diff --git a/discord/gateway.py b/discord/gateway.py index 978ff0fc9..2de1f0508 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -54,6 +54,8 @@ __all__ = ( ) if TYPE_CHECKING: + from typing_extensions import Self + from .client import Client from .state import ConnectionState from .voice_client import VoiceClient @@ -62,10 +64,10 @@ if TYPE_CHECKING: class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" - def __init__(self, shard_id, *, resume=True): - self.shard_id = shard_id - self.resume = resume - self.op = 'RESUME' if resume else 'IDENTIFY' + def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None: + self.shard_id: Optional[int] = shard_id + self.resume: bool = resume + self.op: str = 'RESUME' if resume else 'IDENTIFY' class WebSocketClosure(Exception): @@ -225,7 +227,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler): ack_time = time.perf_counter() self._last_ack = ack_time self._last_recv = ack_time - self.latency = ack_time - self._last_send + self.latency: float = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) @@ -339,7 +341,7 @@ class DiscordWebSocket: @classmethod async def from_client( - cls: Type[DWS], + cls, client: Client, *, initial: bool = False, @@ -348,7 +350,7 @@ class DiscordWebSocket: session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False, - ) -> DWS: + ) -> Self: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -821,11 +823,11 @@ class DiscordVoiceWebSocket: *, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, ) -> None: - self.ws = socket - self.loop = loop - self._keep_alive = None - self._close_code = None - self.secret_key = None + self.ws: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop + self._keep_alive: Optional[VoiceKeepAliveHandler] = None + self._close_code: Optional[int] = None + self.secret_key: Optional[str] = None if hook: self._hook = hook # type: ignore - type-checker doesn't like overriding methods @@ -864,7 +866,9 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) @classmethod - async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS: + async def from_client( + cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None + ) -> Self: """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' http = client._state.http diff --git a/discord/guild.py b/discord/guild.py index 9c713c06a..9cc56f5fe 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -123,6 +123,7 @@ if TYPE_CHECKING: ) from .types.integration import IntegrationType from .types.snowflake import SnowflakeList + from .types.widget import EditWidgetSettings VocalGuildChannel = Union[VoiceChannel, StageChannel] GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel] @@ -3379,7 +3380,7 @@ class Guild(Hashable): HTTPException Editing the widget failed. """ - payload = {} + payload: EditWidgetSettings = {} if channel is not MISSING: payload['channel_id'] = None if channel is None else channel.id if enabled is not MISSING: @@ -3492,7 +3493,7 @@ class Guild(Hashable): async def change_voice_state( self, *, channel: Optional[abc.Snowflake], self_mute: bool = False, self_deaf: bool = False - ): + ) -> None: """|coro| Changes client's voice state in the guild. diff --git a/discord/http.py b/discord/http.py index 645111918..200eda52f 100644 --- a/discord/http.py +++ b/discord/http.py @@ -76,12 +76,9 @@ if TYPE_CHECKING: audit_log, channel, command, - components, emoji, - embed, guild, integration, - interactions, invite, member, message, @@ -92,7 +89,6 @@ if TYPE_CHECKING: channel, widget, threads, - voice, scheduled_event, sticker, ) @@ -122,7 +118,7 @@ class MultipartParameters(NamedTuple): multipart: Optional[List[Dict[str, Any]]] files: Optional[List[File]] - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__( @@ -577,7 +573,7 @@ class HTTPClient: return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload) - def leave_group(self, channel_id) -> Response[None]: + def leave_group(self, channel_id: Snowflake) -> Response[None]: return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id)) # Message management @@ -1160,7 +1156,7 @@ class HTTPClient: def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) - def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]: + def edit_template(self, guild_id: Snowflake, code: str, payload: Dict[str, Any]) -> Response[template.Template]: valid_keys = ( 'name', 'description', @@ -1420,7 +1416,9 @@ class HTTPClient: def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]: return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id)) - def edit_widget(self, guild_id: Snowflake, payload, reason: Optional[str] = None) -> Response[widget.WidgetSettings]: + def edit_widget( + self, guild_id: Snowflake, payload: widget.EditWidgetSettings, reason: Optional[str] = None + ) -> Response[widget.WidgetSettings]: return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason) # Invite management @@ -1812,7 +1810,9 @@ class HTTPClient: ) return self.request(r) - def upsert_global_command(self, application_id: Snowflake, payload) -> Response[command.ApplicationCommand]: + def upsert_global_command( + self, application_id: Snowflake, payload: command.ApplicationCommand + ) -> Response[command.ApplicationCommand]: r = Route('POST', '/applications/{application_id}/commands', application_id=application_id) return self.request(r, json=payload) @@ -1845,7 +1845,9 @@ class HTTPClient: ) return self.request(r) - def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[command.ApplicationCommand]]: + def bulk_upsert_global_commands( + self, application_id: Snowflake, payload: List[Dict[str, Any]] + ) -> Response[List[command.ApplicationCommand]]: r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id) return self.request(r, json=payload) diff --git a/discord/integrations.py b/discord/integrations.py index 406b52633..d3a2ec244 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -39,6 +39,9 @@ __all__ = ( ) if TYPE_CHECKING: + from .guild import Guild + from .role import Role + from .state import ConnectionState from .types.integration import ( IntegrationAccount as IntegrationAccountPayload, Integration as IntegrationPayload, @@ -47,8 +50,6 @@ if TYPE_CHECKING: IntegrationType, IntegrationApplication as IntegrationApplicationPayload, ) - from .guild import Guild - from .role import Role class IntegrationAccount: @@ -109,11 +110,11 @@ class Integration: ) def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: - self.guild = guild - self._state = guild._state + self.guild: Guild = guild + self._state: ConnectionState = guild._state self._from_data(data) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>" def _from_data(self, data: IntegrationPayload) -> None: @@ -123,7 +124,7 @@ class Integration: self.account: IntegrationAccount = IntegrationAccount(data['account']) user = data.get('user') - self.user = User(state=self._state, data=user) if user else None + self.user: Optional[User] = User(state=self._state, data=user) if user else None self.enabled: bool = data['enabled'] async def delete(self, *, reason: Optional[str] = None) -> None: @@ -319,7 +320,7 @@ class IntegrationApplication: 'user', ) - def __init__(self, *, data: IntegrationApplicationPayload, state): + def __init__(self, *, data: IntegrationApplicationPayload, state: ConnectionState) -> None: self.id: int = int(data['id']) self.name: str = data['name'] self.icon: Optional[str] = data['icon'] @@ -358,7 +359,7 @@ class BotIntegration(Integration): def _from_data(self, data: BotIntegrationPayload) -> None: super()._from_data(data) - self.application = IntegrationApplication(data=data['application'], state=self._state) + self.application: IntegrationApplication = IntegrationApplication(data=data['application'], state=self._state) def _integration_factory(value: str) -> Tuple[Type[Integration], str]: diff --git a/discord/interactions.py b/discord/interactions.py index ed4ec24e8..9752063d9 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -54,6 +54,9 @@ if TYPE_CHECKING: Interaction as InteractionPayload, InteractionData, ) + from .types.webhook import ( + Webhook as WebhookPayload, + ) from .client import Client from .guild import Guild from .state import ConnectionState @@ -229,7 +232,7 @@ class Interaction: @utils.cached_slot_property('_cs_followup') def followup(self) -> Webhook: """:class:`Webhook`: Returns the follow up webhook for follow up interactions.""" - payload = { + payload: WebhookPayload = { 'id': self.application_id, 'type': 3, 'token': self.token, @@ -703,7 +706,7 @@ class InteractionResponse: self._responded = True - async def send_modal(self, modal: Modal, /): + async def send_modal(self, modal: Modal, /) -> None: """|coro| Responds to this interaction by sending a modal. diff --git a/discord/invite.py b/discord/invite.py index 95f8a92fa..d640a05c5 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -456,7 +456,7 @@ class Invite(Hashable): guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) channel_id = int(data['channel_id']) if guild is not None: - channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore + channel = guild.get_channel(channel_id) or Object(id=channel_id) else: guild = Object(id=guild_id) if guild_id is not None else None channel = Object(id=channel_id) @@ -539,7 +539,7 @@ class Invite(Hashable): return self - async def delete(self, *, reason: Optional[str] = None): + async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| Revokes the instant invite. diff --git a/discord/member.py b/discord/member.py index fa5f28f67..284cc3fcc 100644 --- a/discord/member.py +++ b/discord/member.py @@ -27,9 +27,8 @@ from __future__ import annotations import datetime import inspect import itertools -import sys from operator import attrgetter -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union +from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type import discord.abc @@ -207,7 +206,7 @@ class _ClientStatus: return self -def flatten_user(cls): +def flatten_user(cls: Any) -> Type[Member]: for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): # ignore private/special methods if attr.startswith('_'): @@ -333,7 +332,7 @@ class Member(discord.abc.Messageable, _UserTag): default_avatar: Asset avatar: Optional[Asset] dm_channel: Optional[DMChannel] - create_dm = User.create_dm + create_dm: Callable[[], Coroutine[Any, Any, DMChannel]] mutual_guilds: List[Guild] public_flags: PublicUserFlags banner: Optional[Asset] @@ -369,10 +368,10 @@ class Member(discord.abc.Messageable, _UserTag): f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>' ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, _UserTag) and other.id == self.id - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -425,7 +424,7 @@ class Member(discord.abc.Messageable, _UserTag): self._user = member._user return self - async def _get_channel(self): + async def _get_channel(self) -> DMChannel: ch = await self.create_dm() return ch diff --git a/discord/mentions.py b/discord/mentions.py index 25e9c29f2..6b7a66dec 100644 --- a/discord/mentions.py +++ b/discord/mentions.py @@ -92,10 +92,10 @@ class AllowedMentions: roles: Union[bool, List[Snowflake]] = default, replied_user: bool = default, ): - self.everyone = everyone - self.users = users - self.roles = roles - self.replied_user = replied_user + self.everyone: bool = everyone + self.users: Union[bool, List[Snowflake]] = users + self.roles: Union[bool, List[Snowflake]] = roles + self.replied_user: bool = replied_user @classmethod def all(cls) -> Self: diff --git a/discord/message.py b/discord/message.py index 7832d2095..ed195475f 100644 --- a/discord/message.py +++ b/discord/message.py @@ -40,6 +40,7 @@ from typing import ( Tuple, ClassVar, Optional, + Type, overload, ) @@ -71,7 +72,6 @@ if TYPE_CHECKING: MessageReference as MessageReferencePayload, MessageApplication as MessageApplicationPayload, MessageActivity as MessageActivityPayload, - Reaction as ReactionPayload, ) from .types.components import Component as ComponentPayload @@ -87,7 +87,7 @@ if TYPE_CHECKING: from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel from .components import Component from .state import ConnectionState - from .channel import TextChannel, GroupChannel, DMChannel + from .channel import TextChannel from .mentions import AllowedMentions from .user import User from .role import Role @@ -95,6 +95,7 @@ if TYPE_CHECKING: EmojiInputType = Union[Emoji, PartialEmoji, str] + __all__ = ( 'Attachment', 'Message', @@ -104,7 +105,7 @@ __all__ = ( ) -def convert_emoji_reaction(emoji): +def convert_emoji_reaction(emoji: Union[EmojiInputType, Reaction]) -> str: if isinstance(emoji, Reaction): emoji = emoji.emoji @@ -216,7 +217,7 @@ class Attachment(Hashable): async def save( self, - fp: Union[io.BufferedIOBase, PathLike], + fp: Union[io.BufferedIOBase, PathLike[Any]], *, seek_begin: bool = True, use_cached: bool = False, @@ -510,7 +511,7 @@ class MessageReference: to_message_reference_dict = to_dict -def flatten_handlers(cls): +def flatten_handlers(cls: Type[Message]) -> Type[Message]: prefix = len('_handle_') handlers = [ (key[prefix:], value) @@ -1036,7 +1037,7 @@ class Message(Hashable): ) @utils.cached_slot_property('_cs_system_content') - def system_content(self): + def system_content(self) -> Optional[str]: r""":class:`str`: A property that returns the content that is rendered regardless of the :attr:`Message.type`. @@ -1657,7 +1658,7 @@ class Message(Hashable): ) return Thread(guild=self.guild, state=self._state, data=data) - async def reply(self, content: Optional[str] = None, **kwargs) -> Message: + async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message: """|coro| A shortcut method to :meth:`.abc.Messageable.send` to reply to the @@ -1798,7 +1799,7 @@ class PartialMessage(Hashable): # Also needed for duck typing purposes # n.b. not exposed - pinned = property(None, lambda x, y: None) + pinned: Any = property(None, lambda x, y: None) def __repr__(self) -> str: return f'' diff --git a/discord/opus.py b/discord/opus.py index a417756e5..33641554e 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -363,7 +363,7 @@ class Encoder(_OpusStruct): _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) def set_expected_packet_loss_percent(self, percentage: float) -> None: - _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) def encode(self, pcm: bytes, frame_size: int) -> bytes: max_data_bytes = len(pcm) @@ -373,8 +373,7 @@ class Encoder(_OpusStruct): ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) - # array can be initialized with bytes but mypy doesn't know - return array.array('b', data[:ret]).tobytes() # type: ignore + return array.array('b', data[:ret]).tobytes() class Decoder(_OpusStruct): diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index c95d5ed31..cf65efb3b 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime from .types.message import PartialEmoji as PartialEmojiPayload + from .types.activity import ActivityEmoji class _EmojiTag: @@ -99,13 +100,13 @@ class PartialEmoji(_EmojiTag, AssetMixin): id: Optional[int] def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None): - self.animated = animated - self.name = name - self.id = id + self.animated: bool = animated + self.name: str = name + self.id: Optional[int] = id self._state: Optional[ConnectionState] = None @classmethod - def from_dict(cls, data: Union[PartialEmojiPayload, Dict[str, Any]]) -> Self: + def from_dict(cls, data: Union[PartialEmojiPayload, ActivityEmoji, Dict[str, Any]]) -> Self: return cls( animated=data.get('animated', False), id=utils._get_as_snowflake(data, 'id'), @@ -178,10 +179,10 @@ class PartialEmoji(_EmojiTag, AssetMixin): return f'' return f'<:{self.name}:{self.id}>' - def __repr__(self): + def __repr__(self) -> str: return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if self.is_unicode_emoji(): return isinstance(other, PartialEmoji) and self.name == other.name @@ -189,7 +190,7 @@ class PartialEmoji(_EmojiTag, AssetMixin): return self.id == other.id return False - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: diff --git a/discord/permissions.py b/discord/permissions.py index 5c81f8445..e2b053b29 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -276,7 +276,7 @@ class Permissions(BaseFlags): # So 0000 OP2 0101 -> 0101 # The OP is base & ~denied. # The OP2 is base | allowed. - self.value = (self.value & ~deny) | allow + self.value: int = (self.value & ~deny) | allow @flag_value def create_instant_invite(self) -> int: @@ -691,7 +691,7 @@ class PermissionOverwrite: setattr(self, key, value) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, PermissionOverwrite) and self._values == other._values def _set(self, key: str, value: Optional[bool]) -> None: diff --git a/discord/player.py b/discord/player.py index cf215756b..76940c31e 100644 --- a/discord/player.py +++ b/discord/player.py @@ -365,12 +365,11 @@ class FFmpegOpusAudio(FFmpegAudio): bitrate: Optional[int] = None, codec: Optional[str] = None, executable: str = 'ffmpeg', - pipe=False, - stderr=None, - before_options=None, - options=None, + pipe: bool = False, + stderr: Optional[IO[bytes]] = None, + before_options: Optional[str] = None, + options: Optional[str] = None, ) -> None: - args = [] subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} @@ -635,7 +634,13 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 - def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): + def __init__( + self, + source: AudioSource, + client: VoiceClient, + *, + after: Optional[Callable[[Optional[Exception]], Any]] = None, + ) -> None: threading.Thread.__init__(self) self.daemon: bool = True self.source: AudioSource = source @@ -724,8 +729,8 @@ class AudioPlayer(threading.Thread): self._speak(SpeakingState.none) def resume(self, *, update_speaking: bool = True) -> None: - self.loops = 0 - self._start = time.perf_counter() + self.loops: int = 0 + self._start: float = time.perf_counter() self._resumed.set() if update_speaking: self._speak(SpeakingState.voice) diff --git a/discord/reaction.py b/discord/reaction.py index 128986385..6bf8a9f45 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -94,10 +94,10 @@ class Reaction: """:class:`bool`: If this is a custom emoji.""" return not isinstance(self.emoji, str) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and other.emoji == self.emoji - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: if isinstance(other, self.__class__): return other.emoji != self.emoji return True diff --git a/discord/role.py b/discord/role.py index 0a8f49627..e503b7d08 100644 --- a/discord/role.py +++ b/discord/role.py @@ -211,7 +211,7 @@ class Role(Hashable): def __repr__(self) -> str: return f'' - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, Role) or not isinstance(self, Role): return NotImplemented @@ -241,7 +241,7 @@ class Role(Hashable): def __gt__(self, other: Any) -> bool: return Role.__lt__(other, self) - def __ge__(self, other: Any) -> bool: + def __ge__(self, other: object) -> bool: r = Role.__lt__(self, other) if r is NotImplemented: return NotImplemented diff --git a/discord/scheduled_event.py b/discord/scheduled_event.py index 1fd93bdc9..059e26472 100644 --- a/discord/scheduled_event.py +++ b/discord/scheduled_event.py @@ -132,7 +132,7 @@ class ScheduledEvent(Hashable): self.guild_id: int = int(data['guild_id']) self.name: str = data['name'] self.description: Optional[str] = data.get('description') - self.entity_type = try_enum(EntityType, data['entity_type']) + self.entity_type: EntityType = try_enum(EntityType, data['entity_type']) self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id') self.start_time: datetime = parse_time(data['scheduled_start_time']) self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status']) @@ -153,7 +153,7 @@ class ScheduledEvent(Hashable): self.location: Optional[str] = data.get('location') if data else None @classmethod - def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload): + def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload) -> None: creator_id = data.get('creator_id') self = cls(state=state, data=data) if creator_id: @@ -180,7 +180,7 @@ class ScheduledEvent(Hashable): return self.guild.get_channel(self.channel_id) # type: ignore @property - def url(self): + def url(self) -> str: """:class:`str`: The url for the scheduled event.""" return f'https://discord.com/events/{self.guild_id}/{self.id}' diff --git a/discord/shard.py b/discord/shard.py index 59a350515..5592004d8 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -75,12 +75,12 @@ class EventItem: self.shard: Optional['Shard'] = shard self.error: Optional[Exception] = error - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, EventItem): return NotImplemented return self.type < other.type - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, EventItem): return NotImplemented return self.type == other.type @@ -409,6 +409,7 @@ class AutoShardedClient(Client): async def launch_shards(self) -> None: if self.shard_count is None: + self.shard_count: int self.shard_count, gateway = await self.http.get_bot_gateway() else: gateway = await self.http.get_gateway() diff --git a/discord/stage_instance.py b/discord/stage_instance.py index fefc76e82..fa1c67304 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -97,11 +97,11 @@ class StageInstance(Hashable): ) def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: - self._state = state - self.guild = guild + self._state: ConnectionState = state + self.guild: Guild = guild self._update(data) - def _update(self, data: StageInstancePayload): + def _update(self, data: StageInstancePayload) -> None: self.id: int = int(data['id']) self.channel_id: int = int(data['channel_id']) self.topic: str = data['topic'] diff --git a/discord/state.py b/discord/state.py index aa4e36eaf..88b9c5037 100644 --- a/discord/state.py +++ b/discord/state.py @@ -43,6 +43,8 @@ from typing import ( Sequence, Tuple, Deque, + Literal, + overload, ) import weakref import inspect @@ -88,7 +90,7 @@ if TYPE_CHECKING: from .types.activity import Activity as ActivityPayload from .types.channel import DMChannel as DMChannelPayload from .types.user import User as UserPayload, PartialUser as PartialUserPayload - from .types.emoji import Emoji as EmojiPayload + from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload @@ -165,9 +167,9 @@ class ConnectionState: def __init__( self, *, - dispatch: Callable, - handlers: Dict[str, Callable], - hooks: Dict[str, Callable], + dispatch: Callable[..., Any], + handlers: Dict[str, Callable[..., Any]], + hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]], http: HTTPClient, **options: Any, ) -> None: @@ -178,9 +180,9 @@ class ConnectionState: if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 - self.dispatch: Callable = dispatch - self.handlers: Dict[str, Callable] = handlers - self.hooks: Dict[str, Callable] = hooks + self.dispatch: Callable[..., Any] = dispatch + self.handlers: Dict[str, Callable[..., Any]] = handlers + self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id') @@ -245,6 +247,7 @@ class ConnectionState: if not intents.members or cache_flags._empty: self.store_user = self.store_user_no_intents # type: ignore - This reassignment is on purpose + self.parsers: Dict[str, Callable[[Any], None]] self.parsers = parsers = {} for attr, func in inspect.getmembers(self): if attr.startswith('parse_'): @@ -343,13 +346,13 @@ class ConnectionState: self._users[user_id] = user return user - def store_user_no_intents(self, data): + def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload]) -> User: return User(state=self, data=data) def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User: return User(state=self, data=data) - def get_user(self, id): + def get_user(self, id: int) -> Optional[User]: return self._users.get(id) def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: @@ -571,8 +574,7 @@ class ConnectionState: pass else: self.application_id = utils._get_as_snowflake(application, 'id') - # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) + self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags']) for guild_data in data['guilds']: self._add_guild_from_data(guild_data) # type: ignore @@ -743,7 +745,7 @@ class ConnectionState: self.dispatch('presence_update', old_member, member) - def parse_user_update(self, data: gw.UserUpdateEvent): + def parse_user_update(self, data: gw.UserUpdateEvent) -> None: if self.user: self.user._update(data) @@ -1050,7 +1052,7 @@ class ConnectionState: guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) - def _get_create_guild(self, data): + def _get_create_guild(self, data: gw.GuildCreateEvent) -> Guild: if data.get('unavailable') is False: # GUILD_CREATE with unavailable in the response # usually means that the guild has become available @@ -1063,10 +1065,22 @@ class ConnectionState: return self._add_guild_from_data(data) - def is_guild_evicted(self, guild) -> bool: + def is_guild_evicted(self, guild: Guild) -> bool: return guild.id not in self._guilds - async def chunk_guild(self, guild, *, wait=True, cache=None): + @overload + async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., cache: Optional[bool] = ...) -> List[Member]: + ... + + @overload + async def chunk_guild( + self, guild: Guild, *, wait: Literal[False] = ..., cache: Optional[bool] = ... + ) -> asyncio.Future[List[Member]]: + ... + + async def chunk_guild( + self, guild: Guild, *, wait: bool = True, cache: Optional[bool] = None + ) -> Union[List[Member], asyncio.Future[List[Member]]]: cache = cache or self.member_cache_flags.joined request = self._chunk_requests.get(guild.id) if request is None: @@ -1445,16 +1459,19 @@ class ConnectionState: return channel.guild.get_member(user_id) return self.get_user(user_id) - def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: + def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]: emoji_id = utils._get_as_snowflake(data, 'id') if not emoji_id: - return data['name'] + # the name key will be a str + return data['name'] # type: ignore try: return self._emojis[emoji_id] except KeyError: - return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) + return PartialEmoji.with_state( + self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore + ) def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: emoji_id = emoji.id @@ -1589,6 +1606,7 @@ class AutoShardedConnectionState(ConnectionState): if not hasattr(self, '_ready_state'): self._ready_state = asyncio.Queue() + self.user: Optional[ClientUser] self.user = user = ClientUser(state=self, data=data['user']) # self._users is a list of Users, we're setting a ClientUser self._users[user.id] = user # type: ignore @@ -1599,8 +1617,8 @@ class AutoShardedConnectionState(ConnectionState): except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') - self.application_flags = ApplicationFlags._from_value(application['flags']) + self.application_id: Optional[int] = utils._get_as_snowflake(application, 'id') + self.application_flags: ApplicationFlags = ApplicationFlags._from_value(application['flags']) for guild_data in data['guilds']: self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload diff --git a/discord/sticker.py b/discord/sticker.py index 4ac1234f1..456b3d720 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -228,7 +228,7 @@ class StickerItem(_StickerTag): The retrieved sticker. """ data: StickerPayload = await self._state.http.get_sticker(self.id) - cls, _ = _sticker_factory(data['type']) # type: ignore + cls, _ = _sticker_factory(data['type']) return cls(state=self._state, data=data) diff --git a/discord/threads.py b/discord/threads.py index 0662feadd..cdfd1f482 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -41,6 +41,8 @@ __all__ = ( ) if TYPE_CHECKING: + from typing_extensions import Self + from .types.threads import ( Thread as ThreadPayload, ThreadMember as ThreadMemberPayload, @@ -147,13 +149,13 @@ class Thread(Messageable, Hashable): '_created_at', ) - def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): + def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None: self._state: ConnectionState = state - self.guild = guild + self.guild: Guild = guild self._members: Dict[int, ThreadMember] = {} self._from_data(data) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self def __repr__(self) -> str: @@ -166,17 +168,18 @@ class Thread(Messageable, Hashable): return self.name def _from_data(self, data: ThreadPayload): - self.id = int(data['id']) - self.parent_id = int(data['parent_id']) - self.owner_id = int(data['owner_id']) - self.name = data['name'] - self._type = try_enum(ChannelType, data['type']) - self.last_message_id = _get_as_snowflake(data, 'last_message_id') - self.slowmode_delay = data.get('rate_limit_per_user', 0) - self.message_count = data['message_count'] - self.member_count = data['member_count'] + self.id: int = int(data['id']) + self.parent_id: int = int(data['parent_id']) + self.owner_id: int = int(data['owner_id']) + self.name: str = data['name'] + self._type: ChannelType = try_enum(ChannelType, data['type']) + self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id') + self.slowmode_delay: int = data.get('rate_limit_per_user', 0) + self.message_count: int = data['message_count'] + self.member_count: int = data['member_count'] self._unroll_metadata(data['thread_metadata']) + self.me: Optional[ThreadMember] try: member = data['member'] except KeyError: @@ -185,15 +188,15 @@ class Thread(Messageable, Hashable): self.me = ThreadMember(self, member) def _unroll_metadata(self, data: ThreadMetadata): - self.archived = data['archived'] - self.archiver_id = _get_as_snowflake(data, 'archiver_id') - self.auto_archive_duration = data['auto_archive_duration'] - self.archive_timestamp = parse_time(data['archive_timestamp']) - self.locked = data.get('locked', False) - self.invitable = data.get('invitable', True) - self._created_at = parse_time(data.get('create_timestamp')) - - def _update(self, data): + self.archived: bool = data['archived'] + self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id') + self.auto_archive_duration: int = data['auto_archive_duration'] + self.archive_timestamp: datetime = parse_time(data['archive_timestamp']) + self.locked: bool = data.get('locked', False) + self.invitable: bool = data.get('invitable', True) + self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp')) + + def _update(self, data: ThreadPayload) -> None: try: self.name = data['name'] except KeyError: @@ -602,7 +605,7 @@ class Thread(Messageable, Hashable): # The data payload will always be a Thread payload return Thread(data=data, state=self._state, guild=self.guild) # type: ignore - async def join(self): + async def join(self) -> None: """|coro| Joins this thread. @@ -619,7 +622,7 @@ class Thread(Messageable, Hashable): """ await self._state.http.join_thread(self.id) - async def leave(self): + async def leave(self) -> None: """|coro| Leaves this thread. @@ -631,7 +634,7 @@ class Thread(Messageable, Hashable): """ await self._state.http.leave_thread(self.id) - async def add_user(self, user: Snowflake, /): + async def add_user(self, user: Snowflake, /) -> None: """|coro| Adds a user to this thread. @@ -654,7 +657,7 @@ class Thread(Messageable, Hashable): """ await self._state.http.add_user_to_thread(self.id, user.id) - async def remove_user(self, user: Snowflake, /): + async def remove_user(self, user: Snowflake, /) -> None: """|coro| Removes a user from this thread. @@ -718,7 +721,7 @@ class Thread(Messageable, Hashable): members = await self._state.http.get_thread_members(self.id) return [ThreadMember(parent=self, data=data) for data in members] - async def delete(self): + async def delete(self) -> None: """|coro| Deletes this thread. @@ -806,28 +809,28 @@ class ThreadMember(Hashable): 'parent', ) - def __init__(self, parent: Thread, data: ThreadMemberPayload): - self.parent = parent - self._state = parent._state + def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None: + self.parent: Thread = parent + self._state: ConnectionState = parent._state self._from_data(data) def __repr__(self) -> str: return f'' - def _from_data(self, data: ThreadMemberPayload): + def _from_data(self, data: ThreadMemberPayload) -> None: try: self.id = int(data['user_id']) except KeyError: - assert self._state.self_id is not None - self.id = self._state.self_id + self.id = self._state.self_id # type: ignore + self.thread_id: int try: self.thread_id = int(data['id']) except KeyError: self.thread_id = self.parent.id - self.joined_at = parse_time(data['join_timestamp']) - self.flags = data['flags'] + self.joined_at: datetime = parse_time(data['join_timestamp']) + self.flags: int = data['flags'] @property def thread(self) -> Thread: diff --git a/discord/types/activity.py b/discord/types/activity.py index c7b2b5dc5..c9fd606ad 100644 --- a/discord/types/activity.py +++ b/discord/types/activity.py @@ -112,3 +112,4 @@ class Activity(_BaseActivity, total=False): session_id: Optional[str] instance: bool buttons: List[ActivityButton] + sync_id: str diff --git a/discord/types/widget.py b/discord/types/widget.py index d0db3b5eb..79e5fef87 100644 --- a/discord/types/widget.py +++ b/discord/types/widget.py @@ -58,3 +58,8 @@ class Widget(TypedDict): class WidgetSettings(TypedDict): enabled: bool channel_id: Optional[Snowflake] + + +class EditWidgetSettings(TypedDict, total=False): + enabled: bool + channel_id: Optional[Snowflake] diff --git a/discord/ui/button.py b/discord/ui/button.py index 163e24ab9..6e6da1fc6 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from .view import View from ..emoji import Emoji + from ..types.components import ButtonComponent as ButtonComponentPayload V = TypeVar('V', bound='View', covariant=True) @@ -124,7 +125,7 @@ class Button(Item[V]): style=style, emoji=emoji, ) - self.row = row + self.row: Optional[int] = row @property def style(self) -> ButtonStyle: @@ -132,7 +133,7 @@ class Button(Item[V]): return self._underlying.style @style.setter - def style(self, value: ButtonStyle): + def style(self, value: ButtonStyle) -> None: self._underlying.style = value @property @@ -144,7 +145,7 @@ class Button(Item[V]): return self._underlying.custom_id @custom_id.setter - def custom_id(self, value: Optional[str]): + def custom_id(self, value: Optional[str]) -> None: if value is not None and not isinstance(value, str): raise TypeError('custom_id must be None or str') @@ -156,7 +157,7 @@ class Button(Item[V]): return self._underlying.url @url.setter - def url(self, value: Optional[str]): + def url(self, value: Optional[str]) -> None: if value is not None and not isinstance(value, str): raise TypeError('url must be None or str') self._underlying.url = value @@ -167,7 +168,7 @@ class Button(Item[V]): return self._underlying.disabled @disabled.setter - def disabled(self, value: bool): + def disabled(self, value: bool) -> None: self._underlying.disabled = bool(value) @property @@ -176,7 +177,7 @@ class Button(Item[V]): return self._underlying.label @label.setter - def label(self, value: Optional[str]): + def label(self, value: Optional[str]) -> None: self._underlying.label = str(value) if value is not None else value @property @@ -185,7 +186,7 @@ class Button(Item[V]): return self._underlying.emoji @emoji.setter - def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore + def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]) -> None: if value is not None: if isinstance(value, str): self._underlying.emoji = PartialEmoji.from_str(value) @@ -212,7 +213,7 @@ class Button(Item[V]): def type(self) -> ComponentType: return self._underlying.type - def to_component_dict(self): + def to_component_dict(self) -> ButtonComponentPayload: return self._underlying.to_dict() def is_dispatchable(self) -> bool: diff --git a/discord/ui/item.py b/discord/ui/item.py index 89b706c10..9bace2583 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -101,7 +101,7 @@ class Item(Generic[V]): return self._row @row.setter - def row(self, value: Optional[int]): + def row(self, value: Optional[int]) -> None: if value is None: self._row = None elif 5 > value >= 0: @@ -118,7 +118,7 @@ class Item(Generic[V]): """Optional[:class:`View`]: The underlying view for this item.""" return self._view - async def callback(self, interaction: Interaction): + async def callback(self, interaction: Interaction) -> Any: """|coro| The callback associated with this UI item. diff --git a/discord/ui/modal.py b/discord/ui/modal.py index db52cdf6f..1f6e2c42a 100644 --- a/discord/ui/modal.py +++ b/discord/ui/modal.py @@ -38,6 +38,8 @@ from .item import Item from .view import View if TYPE_CHECKING: + from typing_extensions import Self + from ..interactions import Interaction from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload @@ -101,7 +103,7 @@ class Modal(View): title: str __discord_ui_modal__ = True - __modal_children_items__: ClassVar[Dict[str, Item]] = {} + __modal_children_items__: ClassVar[Dict[str, Item[Self]]] = {} def __init_subclass__(cls, *, title: str = MISSING) -> None: if title is not MISSING: @@ -139,7 +141,7 @@ class Modal(View): super().__init__(timeout=timeout) - async def on_submit(self, interaction: Interaction): + async def on_submit(self, interaction: Interaction) -> None: """|coro| Called when the modal is submitted. @@ -169,7 +171,7 @@ class Modal(View): print(f'Ignoring exception in modal {self}:', file=sys.stderr) traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) - def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]): + def refresh(self, components: Sequence[ModalSubmitComponentInteractionDataPayload]) -> None: for component in components: if component['type'] == 1: self.refresh(component['components']) diff --git a/discord/ui/select.py b/discord/ui/select.py index e1dfd9ea1..13bd542c9 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -121,7 +121,7 @@ class Select(Item[V]): options=options, disabled=disabled, ) - self.row = row + self.row: Optional[int] = row @property def custom_id(self) -> str: @@ -129,7 +129,7 @@ class Select(Item[V]): return self._underlying.custom_id @custom_id.setter - def custom_id(self, value: str): + def custom_id(self, value: str) -> None: if not isinstance(value, str): raise TypeError('custom_id must be None or str') @@ -141,7 +141,7 @@ class Select(Item[V]): return self._underlying.placeholder @placeholder.setter - def placeholder(self, value: Optional[str]): + def placeholder(self, value: Optional[str]) -> None: if value is not None and not isinstance(value, str): raise TypeError('placeholder must be None or str') @@ -153,7 +153,7 @@ class Select(Item[V]): return self._underlying.min_values @min_values.setter - def min_values(self, value: int): + def min_values(self, value: int) -> None: self._underlying.min_values = int(value) @property @@ -162,7 +162,7 @@ class Select(Item[V]): return self._underlying.max_values @max_values.setter - def max_values(self, value: int): + def max_values(self, value: int) -> None: self._underlying.max_values = int(value) @property @@ -171,7 +171,7 @@ class Select(Item[V]): return self._underlying.options @options.setter - def options(self, value: List[SelectOption]): + def options(self, value: List[SelectOption]) -> None: if not isinstance(value, list): raise TypeError('options must be a list of SelectOption') if not all(isinstance(obj, SelectOption) for obj in value): @@ -187,7 +187,7 @@ class Select(Item[V]): description: Optional[str] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, default: bool = False, - ): + ) -> None: """Adds an option to the select menu. To append a pre-existing :class:`discord.SelectOption` use the @@ -226,7 +226,7 @@ class Select(Item[V]): self.append_option(option) - def append_option(self, option: SelectOption): + def append_option(self, option: SelectOption) -> None: """Appends an option to the select menu. Parameters @@ -251,7 +251,7 @@ class Select(Item[V]): return self._underlying.disabled @disabled.setter - def disabled(self, value: bool): + def disabled(self, value: bool) -> None: self._underlying.disabled = bool(value) @property diff --git a/discord/ui/view.py b/discord/ui/view.py index 7e10a9281..234a6e6e1 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -50,6 +50,8 @@ __all__ = ( if TYPE_CHECKING: + from typing_extensions import Self + from ..interactions import Interaction from ..message import Message from ..types.components import Component as ComponentPayload @@ -163,7 +165,7 @@ class View: cls.__view_children_items__ = children - def _init_children(self) -> List[Item]: + def _init_children(self) -> List[Item[Self]]: children = [] for func in self.__view_children_items__: item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) @@ -175,7 +177,7 @@ class View: def __init__(self, *, timeout: Optional[float] = 180.0): self.timeout = timeout - self.children: List[Item] = self._init_children() + self.children: List[Item[Self]] = self._init_children() self.__weights = _ViewWeights(self.children) self.id: str = os.urandom(16).hex() self.__cancel_callback: Optional[Callable[[View], None]] = None @@ -250,7 +252,7 @@ class View: view.add_item(_component_to_item(component)) return view - def add_item(self, item: Item) -> None: + def add_item(self, item: Item[Any]) -> None: """Adds an item to the view. Parameters @@ -278,7 +280,7 @@ class View: item._view = self self.children.append(item) - def remove_item(self, item: Item) -> None: + def remove_item(self, item: Item[Any]) -> None: """Removes an item from the view. Parameters @@ -334,7 +336,7 @@ class View: """ pass - async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None: + async def on_error(self, error: Exception, item: Item[Any], interaction: Interaction) -> None: """|coro| A callback that is called when an item's callback or :meth:`interaction_check` @@ -395,16 +397,16 @@ class View: asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}') - def refresh(self, components: List[Component]): + def refresh(self, components: List[Component]) -> None: # This is pretty hacky at the moment # fmt: off - old_state: Dict[Tuple[int, str], Item] = { + old_state: Dict[Tuple[int, str], Item[Any]] = { (item.type.value, item.custom_id): item # type: ignore for item in self.children if item.is_dispatchable() } # fmt: on - children: List[Item] = [] + children: List[Item[Any]] = [] for component in _walk_all_components(components): try: older = old_state[(component.type.value, component.custom_id)] # type: ignore @@ -494,7 +496,7 @@ class ViewStore: for k in to_remove: del self._views[k] - def add_view(self, view: View, message_id: Optional[int] = None): + def add_view(self, view: View, message_id: Optional[int] = None) -> None: view._start_listening_from_store(self) if view.__discord_ui_modal__: self._modals[view.custom_id] = view # type: ignore @@ -509,7 +511,7 @@ class ViewStore: if message_id is not None: self._synced_message_views[message_id] = view - def remove_view(self, view: View): + def remove_view(self, view: View) -> None: if view.__discord_ui_modal__: self._modals.pop(view.custom_id, None) # type: ignore return @@ -523,7 +525,7 @@ class ViewStore: del self._synced_message_views[key] break - def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction): + def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None: self.__verify_integrity() message_id: Optional[int] = interaction.message and interaction.message.id key = (component_type, message_id, custom_id) @@ -542,7 +544,7 @@ class ViewStore: custom_id: str, interaction: Interaction, components: List[ModalSubmitComponentInteractionDataPayload], - ): + ) -> None: modal = self._modals.get(custom_id) if modal is None: _log.debug("Modal interaction referencing unknown custom_id %s. Discarding", custom_id) @@ -551,13 +553,13 @@ class ViewStore: modal.refresh(components) modal._dispatch_submit(interaction) - def is_message_tracked(self, message_id: int): + def is_message_tracked(self, message_id: int) -> bool: return message_id in self._synced_message_views def remove_message_tracking(self, message_id: int) -> Optional[View]: return self._synced_message_views.pop(message_id, None) - def update_from_message(self, message_id: int, components: List[ComponentPayload]): + def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None: # pre-req: is_message_tracked == true view = self._synced_message_views[message_id] view.refresh([_component_factory(d) for d in components]) diff --git a/discord/user.py b/discord/user.py index e1bea60df..a4a75e355 100644 --- a/discord/user.py +++ b/discord/user.py @@ -99,10 +99,10 @@ class BaseUser(_UserTag): def __str__(self) -> str: return f'{self.name}#{self.discriminator}' - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, _UserTag) and other.id == self.id - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -444,7 +444,7 @@ class User(BaseUser, discord.abc.Messageable): def __repr__(self) -> str: return f'' - async def _get_channel(self): + async def _get_channel(self) -> DMChannel: ch = await self.create_dm() return ch diff --git a/discord/utils.py b/discord/utils.py index 6c517a380..6bd879fd0 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -29,6 +29,7 @@ from typing import ( Any, AsyncIterable, AsyncIterator, + Awaitable, Callable, Coroutine, Dict, @@ -42,6 +43,7 @@ from typing import ( NamedTuple, Optional, Protocol, + Set, Sequence, Tuple, Type, @@ -66,7 +68,7 @@ import warnings import yarl try: - import orjson + import orjson # type: ignore except ModuleNotFoundError: HAS_ORJSON = False else: @@ -123,7 +125,7 @@ class _cached_property: if TYPE_CHECKING: from functools import cached_property as cached_property - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Self from .permissions import Permissions from .abc import Snowflake @@ -135,8 +137,16 @@ if TYPE_CHECKING: P = ParamSpec('P') + MaybeCoroFunc = Union[ + Callable[P, Coroutine[Any, Any, 'T']], + Callable[P, 'T'], + ] + + _SnowflakeListBase = array.array[int] + else: cached_property = _cached_property + _SnowflakeListBase = array.array T = TypeVar('T') @@ -178,7 +188,7 @@ class classproperty(Generic[T_co]): def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co: return self.fget(owner) - def __set__(self, instance, value) -> None: + def __set__(self, instance: Optional[Any], value: Any) -> None: raise AttributeError('cannot set attribute') @@ -210,7 +220,7 @@ class SequenceProxy(Sequence[T_co]): def __reversed__(self) -> Iterator[T_co]: return reversed(self.__proxied) - def index(self, value: Any, *args, **kwargs) -> int: + def index(self, value: Any, *args: Any, **kwargs: Any) -> int: return self.__proxied.index(value, *args, **kwargs) def count(self, value: Any) -> int: @@ -578,7 +588,7 @@ def _is_submodule(parent: str, child: str) -> bool: if HAS_ORJSON: - def _to_json(obj: Any) -> str: # type: ignore + def _to_json(obj: Any) -> str: return orjson.dumps(obj).decode('utf-8') _from_json = orjson.loads # type: ignore @@ -602,15 +612,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) -async def maybe_coroutine(f, *args, **kwargs): +async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T: value = f(*args, **kwargs) if _isawaitable(value): return await value else: - return value + return value # type: ignore -async def async_all(gen, *, check=_isawaitable): +async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool: for elem in gen: if check(elem): elem = await elem @@ -619,7 +629,7 @@ async def async_all(gen, *, check=_isawaitable): return True -async def sane_wait_for(futures, *, timeout): +async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]: ensured = [asyncio.ensure_future(fut) for fut in futures] done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) @@ -637,7 +647,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]: continue -def compute_timedelta(dt: datetime.datetime): +def compute_timedelta(dt: datetime.datetime) -> float: if dt.tzinfo is None: dt = dt.astimezone() now = datetime.datetime.now(datetime.timezone.utc) @@ -686,7 +696,7 @@ def valid_icon_size(size: int) -> bool: return not size & (size - 1) and 4096 >= size >= 16 -class SnowflakeList(array.array): +class SnowflakeList(_SnowflakeListBase): """Internal data storage class to efficiently store a list of snowflakes. This should have the following characteristics: @@ -705,7 +715,7 @@ class SnowflakeList(array.array): def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... - def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): + def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self: return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore def add(self, element: int) -> None: @@ -1010,7 +1020,7 @@ def evaluate_annotation( cache: Dict[str, Any], *, implicit_str: bool = True, -): +) -> Any: if isinstance(tp, ForwardRef): tp = tp.__forward_arg__ # ForwardRefs always evaluate their internals diff --git a/discord/voice_client.py b/discord/voice_client.py index 208bb78d3..67490e084 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -262,7 +262,7 @@ class VoiceClient(VoiceProtocol): self._lite_nonce: int = 0 self.ws: DiscordVoiceWebSocket = MISSING - warn_nacl = not has_nacl + warn_nacl: bool = not has_nacl supported_modes: Tuple[SupportedModes, ...] = ( 'xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', @@ -279,7 +279,7 @@ class VoiceClient(VoiceProtocol): """:class:`ClientUser`: The user connected to voice (i.e. ourselves).""" return self._state.user # type: ignore - user can't be None after login - def checked_add(self, attr, value, limit): + def checked_add(self, attr: str, value: int, limit: int) -> None: val = getattr(self, attr) if val + value > limit: setattr(self, attr, 0) @@ -289,7 +289,7 @@ class VoiceClient(VoiceProtocol): # connection related async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - self.session_id = data['session_id'] + self.session_id: str = data['session_id'] channel_id = data['channel_id'] if not self._handshaking or self._potentially_reconnecting: @@ -323,12 +323,12 @@ class VoiceClient(VoiceProtocol): self.endpoint, _, _ = endpoint.rpartition(':') if self.endpoint.startswith('wss://'): # Just in case, strip it off since we're going to add it later - self.endpoint = self.endpoint[6:] + self.endpoint: str = self.endpoint[6:] # This gets set later self.endpoint_ip = MISSING - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.setblocking(False) if not self._handshaking: diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 2d06a953b..f7127c830 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -30,7 +30,7 @@ import json import re from urllib.parse import quote as urlquote -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload from contextvars import ContextVar import weakref @@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType from ..user import BaseUser, User from ..flags import MessageFlags from ..asset import Asset -from ..http import Route, handle_message_parameters, MultipartParameters +from ..http import Route, handle_message_parameters, MultipartParameters, HTTPClient from ..mixins import Hashable from ..channel import PartialMessageable from ..file import File @@ -58,24 +58,38 @@ __all__ = ( _log = logging.getLogger(__name__) if TYPE_CHECKING: + from typing_extensions import Self + from types import TracebackType + from ..embeds import Embed from ..mentions import AllowedMentions from ..message import Attachment from ..state import ConnectionState from ..http import Response + from ..guild import Guild + from ..channel import TextChannel + from ..abc import Snowflake + from ..ui.view import View + import datetime from ..types.webhook import ( Webhook as WebhookPayload, + SourceGuild as SourceGuildPayload, ) from ..types.message import ( Message as MessagePayload, ) - from ..guild import Guild - from ..channel import TextChannel - from ..abc import Snowflake - from ..ui.view import View - import datetime + from ..types.user import ( + User as UserPayload, + PartialUser as PartialUserPayload, + ) + from ..types.channel import ( + PartialChannel as PartialChannelPayload, + ) + + BE = TypeVar('BE', bound=BaseException) + _State = Union[ConnectionState, '_WebhookState'] -MISSING = utils.MISSING +MISSING: Any = utils.MISSING class AsyncDeferredLock: @@ -83,14 +97,19 @@ class AsyncDeferredLock: self.lock = lock self.delta: Optional[float] = None - async def __aenter__(self): + async def __aenter__(self) -> Self: await self.lock.acquire() return self def delay_by(self, delta: float) -> None: self.delta = delta - async def __aexit__(self, type, value, traceback): + async def __aexit__( + self, + exc_type: Optional[Type[BE]], + exc: Optional[BE], + traceback: Optional[TracebackType], + ) -> None: if self.delta: await asyncio.sleep(self.delta) self.lock.release() @@ -545,11 +564,11 @@ class PartialWebhookChannel(Hashable): __slots__ = ('id', 'name') - def __init__(self, *, data): - self.id = int(data['id']) - self.name = data['name'] + def __init__(self, *, data: PartialChannelPayload) -> None: + self.id: int = int(data['id']) + self.name: str = data['name'] - def __repr__(self): + def __repr__(self) -> str: return f'' @@ -570,13 +589,13 @@ class PartialWebhookGuild(Hashable): __slots__ = ('id', 'name', '_icon', '_state') - def __init__(self, *, data, state): - self._state = state - self.id = int(data['id']) - self.name = data['name'] - self._icon = data['icon'] + def __init__(self, *, data: SourceGuildPayload, state: _State) -> None: + self._state: _State = state + self.id: int = int(data['id']) + self.name: str = data['name'] + self._icon: str = data['icon'] - def __repr__(self): + def __repr__(self) -> str: return f'' @property @@ -590,14 +609,14 @@ class PartialWebhookGuild(Hashable): class _FriendlyHttpAttributeErrorHelper: __slots__ = () - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: raise AttributeError('PartialWebhookState does not support http methods.') class _WebhookState: __slots__ = ('_parent', '_webhook') - def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]): + def __init__(self, webhook: Any, parent: Optional[_State]): self._webhook: Any = webhook self._parent: Optional[ConnectionState] @@ -606,23 +625,23 @@ class _WebhookState: else: self._parent = parent - def _get_guild(self, guild_id): + def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]: if self._parent is not None: return self._parent._get_guild(guild_id) return None - def store_user(self, data): + def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser: if self._parent is not None: return self._parent.store_user(data) # state parameter is artificial return BaseUser(state=self, data=data) # type: ignore - def create_user(self, data): + def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser: # state parameter is artificial return BaseUser(state=self, data=data) # type: ignore @property - def http(self): + def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]: if self._parent is not None: return self._parent.http @@ -630,7 +649,7 @@ class _WebhookState: # however, using it should result in a late-binding error. return _FriendlyHttpAttributeErrorHelper() - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: if self._parent is not None: return getattr(self._parent, attr) @@ -830,19 +849,24 @@ class BaseWebhook(Hashable): '_state', ) - def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None): + def __init__( + self, + data: WebhookPayload, + token: Optional[str] = None, + state: Optional[_State] = None, + ) -> None: self.auth_token: Optional[str] = token - self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state) + self._state: _State = state or _WebhookState(self, parent=state) self._update(data) - def _update(self, data: WebhookPayload): - self.id = int(data['id']) - self.type = try_enum(WebhookType, int(data['type'])) - self.channel_id = utils._get_as_snowflake(data, 'channel_id') - self.guild_id = utils._get_as_snowflake(data, 'guild_id') - self.name = data.get('name') - self._avatar = data.get('avatar') - self.token = data.get('token') + def _update(self, data: WebhookPayload) -> None: + self.id: int = int(data['id']) + self.type: WebhookType = try_enum(WebhookType, int(data['type'])) + self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') + self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') + self.name: Optional[str] = data.get('name') + self._avatar: Optional[str] = data.get('avatar') + self.token: Optional[str] = data.get('token') user = data.get('user') self.user: Optional[Union[BaseUser, User]] = None @@ -1010,11 +1034,17 @@ class Webhook(BaseWebhook): __slots__: Tuple[str, ...] = ('session',) - def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None): + def __init__( + self, + data: WebhookPayload, + session: aiohttp.ClientSession, + token: Optional[str] = None, + state: Optional[_State] = None, + ) -> None: super().__init__(data, token, state) - self.session = session + self.session: aiohttp.ClientSession = session - def __repr__(self): + def __repr__(self) -> str: return f'' @property @@ -1023,7 +1053,7 @@ class Webhook(BaseWebhook): return f'https://discord.com/api/webhooks/{self.id}/{self.token}' @classmethod - def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: + def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self: """Creates a partial :class:`Webhook`. Parameters @@ -1059,7 +1089,7 @@ class Webhook(BaseWebhook): return cls(data, session, token=bot_token) @classmethod - def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: + def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self: """Creates a partial :class:`Webhook` from a webhook URL. .. versionchanged:: 2.0 @@ -1102,7 +1132,7 @@ class Webhook(BaseWebhook): return cls(data, session, token=bot_token) # type: ignore @classmethod - def _as_follower(cls, data, *, channel, user) -> Webhook: + def _as_follower(cls, data, *, channel, user) -> Self: name = f"{channel.guild} #{channel}" feed: WebhookPayload = { 'id': data['webhook_id'], @@ -1118,8 +1148,8 @@ class Webhook(BaseWebhook): return cls(feed, session=session, state=state, token=state.http.token) @classmethod - def from_state(cls, data, state) -> Webhook: - session = state.http._HTTPClient__session + def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self: + session = state.http._HTTPClient__session # type: ignore return cls(data, session=session, state=state, token=state.http.token) async def fetch(self, *, prefer_auth: bool = True) -> Webhook: @@ -1168,7 +1198,7 @@ class Webhook(BaseWebhook): return Webhook(data, self.session, token=self.auth_token, state=self._state) - async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True): + async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None: """|coro| Deletes this Webhook. diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index ff3983136..28ae496a3 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -37,7 +37,7 @@ import time import re from urllib.parse import quote as urlquote -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, TypeVar, Type, overload import weakref from .. import utils @@ -56,36 +56,50 @@ __all__ = ( _log = logging.getLogger(__name__) if TYPE_CHECKING: + from typing_extensions import Self + from types import TracebackType + from ..file import File from ..embeds import Embed from ..mentions import AllowedMentions from ..message import Attachment + from ..abc import Snowflake + from ..state import ConnectionState from ..types.webhook import ( Webhook as WebhookPayload, ) - from ..abc import Snowflake + from ..types.message import ( + Message as MessagePayload, + ) + + BE = TypeVar('BE', bound=BaseException) try: from requests import Session, Response except ModuleNotFoundError: pass -MISSING = utils.MISSING +MISSING: Any = utils.MISSING class DeferredLock: - def __init__(self, lock: threading.Lock): - self.lock = lock + def __init__(self, lock: threading.Lock) -> None: + self.lock: threading.Lock = lock self.delta: Optional[float] = None - def __enter__(self): + def __enter__(self) -> Self: self.lock.acquire() return self def delay_by(self, delta: float) -> None: self.delta = delta - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BE]], + exc: Optional[BE], + traceback: Optional[TracebackType], + ) -> None: if self.delta: time.sleep(self.delta) self.lock.release() @@ -218,7 +232,7 @@ class WebhookAdapter: token: Optional[str] = None, session: Session, reason: Optional[str] = None, - ): + ) -> None: route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id) return self.request(route, session, reason=reason, auth_token=token) @@ -229,7 +243,7 @@ class WebhookAdapter: *, session: Session, reason: Optional[str] = None, - ): + ) -> None: route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) return self.request(route, session, reason=reason) @@ -241,7 +255,7 @@ class WebhookAdapter: *, session: Session, reason: Optional[str] = None, - ): + ) -> WebhookPayload: route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id) return self.request(route, session, reason=reason, payload=payload, auth_token=token) @@ -253,7 +267,7 @@ class WebhookAdapter: *, session: Session, reason: Optional[str] = None, - ): + ) -> WebhookPayload: route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) return self.request(route, session, reason=reason, payload=payload) @@ -268,7 +282,7 @@ class WebhookAdapter: files: Optional[List[File]] = None, thread_id: Optional[int] = None, wait: bool = False, - ): + ) -> MessagePayload: params = {'wait': int(wait)} if thread_id: params['thread_id'] = thread_id @@ -282,7 +296,7 @@ class WebhookAdapter: message_id: int, *, session: Session, - ): + ) -> MessagePayload: route = Route( 'GET', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', @@ -302,7 +316,7 @@ class WebhookAdapter: payload: Optional[Dict[str, Any]] = None, multipart: Optional[List[Dict[str, Any]]] = None, files: Optional[List[File]] = None, - ): + ) -> MessagePayload: route = Route( 'PATCH', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', @@ -319,7 +333,7 @@ class WebhookAdapter: message_id: int, *, session: Session, - ): + ) -> None: route = Route( 'DELETE', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', @@ -335,7 +349,7 @@ class WebhookAdapter: token: str, *, session: Session, - ): + ) -> WebhookPayload: route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id) return self.request(route, session=session, auth_token=token) @@ -345,7 +359,7 @@ class WebhookAdapter: token: str, *, session: Session, - ): + ) -> WebhookPayload: route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) return self.request(route, session=session) @@ -569,11 +583,17 @@ class SyncWebhook(BaseWebhook): __slots__: Tuple[str, ...] = ('session',) - def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None): + def __init__( + self, + data: WebhookPayload, + session: Session, + token: Optional[str] = None, + state: Optional[Union[ConnectionState, _WebhookState]] = None, + ) -> None: super().__init__(data, token, state) - self.session = session + self.session: Session = session - def __repr__(self): + def __repr__(self) -> str: return f'' @property @@ -812,7 +832,7 @@ class SyncWebhook(BaseWebhook): return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state) - def _create_message(self, data): + def _create_message(self, data: MessagePayload) -> SyncWebhookMessage: state = _WebhookState(self, parent=self._state) # state may be artificial (unlikely at this point...) channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore diff --git a/discord/widget.py b/discord/widget.py index 0a9736cad..5ea0dcbd5 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -278,7 +278,7 @@ class Widget: def __str__(self) -> str: return self.json_url - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, Widget): return self.id == other.id return False