From 4e331704aed2e59b0c6d9ce1ec904e01ce9accd7 Mon Sep 17 00:00:00 2001 From: Vaskel <49348256+ImVaskel@users.noreply.github.com> Date: Fri, 18 Feb 2022 23:04:56 -0500 Subject: [PATCH] [commands] Fix typing problems in commands.converter --- discord/ext/commands/converter.py | 71 +++++++++++++++++-------------- discord/ext/commands/errors.py | 6 +-- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index f314a57df..38a08ecd3 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -49,6 +49,9 @@ from .errors import * if TYPE_CHECKING: from .context import Context from discord.message import PartialMessageableChannel + from .bot import Bot, AutoShardedBot + + _Bot = Union[Bot, AutoShardedBot] __all__ = ( @@ -157,7 +160,7 @@ class ObjectConverter(IDConverter[discord.Object]): 2. Lookup by member, role, or channel mention. """ - async def convert(self, ctx: Context, argument: str) -> discord.Object: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object: match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) if match is None: @@ -221,7 +224,7 @@ class MemberConverter(IDConverter[discord.Member]): return None return members[0] - async def convert(self, ctx: Context, argument: str) -> discord.Member: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member: bot = ctx.bot match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) guild = ctx.guild @@ -252,7 +255,7 @@ class MemberConverter(IDConverter[discord.Member]): if not result: raise MemberNotFound(argument) - return result + return result # type: ignore class UserConverter(IDConverter[discord.User]): @@ -275,7 +278,7 @@ class UserConverter(IDConverter[discord.User]): and it's not available in cache. """ - async def convert(self, ctx: Context, argument: str) -> discord.User: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User: match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) result = None state = ctx._state @@ -289,7 +292,7 @@ class UserConverter(IDConverter[discord.User]): except discord.HTTPException: raise UserNotFound(argument) from None - return result + return result # type: ignore arg = argument @@ -352,7 +355,9 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): return guild_id, message_id, channel_id @staticmethod - def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: + def _resolve_channel( + ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int] + ) -> Optional[PartialMessageableChannel]: if channel_id is None: # we were passed just a message id so we can assume the channel is the current context channel return ctx.channel @@ -365,7 +370,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): return ctx.bot.get_channel(channel_id) - async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage: guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) channel = self._resolve_channel(ctx, guild_id, channel_id) if not channel: @@ -388,7 +393,7 @@ class MessageConverter(IDConverter[discord.Message]): Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.Message: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message: guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) message = ctx.bot._connection._get_message(message_id) if message: @@ -419,7 +424,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): .. versionadded:: 2.0 """ - async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel: return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) @staticmethod @@ -469,7 +474,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): else: thread_id = int(match.group(1)) if guild: - result = guild.get_thread(thread_id) + result = guild.get_thread(thread_id) # type: ignore if not result or not isinstance(result, type): raise ThreadNotFound(argument) @@ -493,7 +498,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) @@ -513,7 +518,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) @@ -532,7 +537,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): 3. Lookup by name """ - async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) @@ -552,7 +557,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) @@ -571,7 +576,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): .. versionadded:: 1.7 """ - async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) @@ -589,7 +594,7 @@ class ThreadConverter(IDConverter[discord.Thread]): .. versionadded: 2.0 """ - async def convert(self, ctx: Context, argument: str) -> discord.Thread: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread: return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) @@ -654,7 +659,7 @@ class ColourConverter(Converter[discord.Colour]): blue = self.parse_rgb_number(argument, match.group('b')) return discord.Color.from_rgb(red, green, blue) - async def convert(self, ctx: Context, argument: str) -> discord.Colour: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour: if argument[0] == '#': return self.parse_hex_number(argument[1:]) @@ -695,7 +700,7 @@ class RoleConverter(IDConverter[discord.Role]): Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.Role: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role: guild = ctx.guild if not guild: raise NoPrivateMessage() @@ -714,7 +719,7 @@ class RoleConverter(IDConverter[discord.Role]): class GameConverter(Converter[discord.Game]): """Converts to :class:`~discord.Game`.""" - async def convert(self, ctx: Context, argument: str) -> discord.Game: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game: return discord.Game(name=argument) @@ -727,7 +732,7 @@ class InviteConverter(Converter[discord.Invite]): Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.Invite: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite: try: invite = await ctx.bot.fetch_invite(argument) return invite @@ -746,7 +751,7 @@ class GuildConverter(IDConverter[discord.Guild]): .. versionadded:: 1.7 """ - async def convert(self, ctx: Context, argument: str) -> discord.Guild: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild: match = self._get_id_match(argument) result = None @@ -778,7 +783,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.Emoji: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji: match = self._get_id_match(argument) or re.match(r'$', argument) result = None bot = ctx.bot @@ -812,7 +817,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji: match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) if match: @@ -841,7 +846,7 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]): .. versionadded:: 2.0 """ - async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker: + async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker: match = self._get_id_match(argument) result = None bot = ctx.bot @@ -899,17 +904,17 @@ class clean_content(Converter[str]): self.escape_markdown = escape_markdown self.remove_markdown = remove_markdown - async def convert(self, ctx: Context, argument: str) -> str: + async def convert(self, ctx: Context[_Bot], argument: str) -> str: msg = ctx.message if ctx.guild: def resolve_member(id: int) -> str: - m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) + m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) # type: ignore return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' def resolve_role(id: int) -> str: - r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) + r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) # type: ignore return f'@{r.name}' if r else '@deleted-role' else: @@ -924,7 +929,7 @@ class clean_content(Converter[str]): if self.fix_channel_mentions and ctx.guild: def resolve_channel(id: int) -> str: - c = ctx.guild.get_channel(id) + c = ctx.guild.get_channel(id) # type: ignore return f'#{c.name}' if c else '#deleted-channel' else: @@ -1000,7 +1005,7 @@ class Greedy(List[T]): raise TypeError('Greedy[...] expects a type or a Converter instance.') if converter in (str, type(None)) or origin is Greedy: - raise TypeError(f'Greedy[{converter.__name__}] is invalid.') + raise TypeError(f'Greedy[{converter.__name__}] is invalid.') # type: ignore if origin is Union and type(None) in args: raise TypeError(f'Greedy[{converter!r}] is invalid.') @@ -1076,13 +1081,13 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp if inspect.ismethod(converter.convert): return await converter.convert(ctx, argument) else: - return await converter().convert(ctx, argument) + return await converter().convert(ctx, argument) # type: ignore elif isinstance(converter, Converter): - return await converter.convert(ctx, argument) + return await converter.convert(ctx, argument) # type: ignore except CommandError: raise except Exception as exc: - raise ConversionError(converter, exc) from exc + raise ConversionError(converter, exc) from exc # type: ignore try: return converter(argument) @@ -1092,7 +1097,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp try: name = converter.__name__ except AttributeError: - name = converter.__class__.__name__ + name = converter.__class__.__name__ # type: ignore raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 938343857..11d442416 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -346,11 +346,11 @@ class ChannelNotFound(BadArgument): Attributes ----------- - argument: :class:`str` + argument: Union[:class:`int`, :class:`str`] The channel supplied by the caller that was not found """ - def __init__(self, argument: str) -> None: - self.argument: str = argument + def __init__(self, argument: Union[int, str]) -> None: + self.argument: Union[int, str] = argument super().__init__(f'Channel "{argument}" not found.') class ThreadNotFound(BadArgument):