diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index f9a8592cf..730113b14 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -26,7 +26,7 @@ from __future__ import annotations import re import inspect -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable +from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable import discord from .errors import * @@ -70,11 +70,12 @@ def _get_from_guilds(bot, getter, argument): _utils_get = discord.utils.get -T = TypeVar('T', covariant=True) +T = TypeVar('T') +T_co = TypeVar('T_co', covariant=True) @runtime_checkable -class Converter(Protocol[T]): +class Converter(Protocol[T_co]): """The base class of custom converters that require the :class:`.Context` to be passed to be useful. @@ -85,7 +86,7 @@ class Converter(Protocol[T]): method to do its conversion logic. This method must be a :ref:`coroutine `. """ - async def convert(self, ctx: Context, argument: str) -> T: + async def convert(self, ctx: Context, argument: str) -> T_co: """|coro| The method to override to do conversion logic. @@ -110,7 +111,8 @@ class Converter(Protocol[T]): """ raise NotImplementedError('Derived classes need to implement this.') -class IDConverter(Converter[T]): + +class IDConverter(Converter[T_co]): def __init__(self): self._id_regex = re.compile(r'([0-9]{15,20})$') super().__init__() @@ -118,6 +120,7 @@ class IDConverter(Converter[T]): def _get_id_match(self, argument): return self._id_regex.match(argument) + class MemberConverter(IDConverter[discord.Member]): """Converts to a :class:`~discord.Member`. @@ -204,6 +207,7 @@ class MemberConverter(IDConverter[discord.Member]): return result + class UserConverter(IDConverter[discord.User]): """Converts to a :class:`~discord.User`. @@ -223,6 +227,7 @@ class UserConverter(IDConverter[discord.User]): This converter now lazily fetches users from the HTTP APIs if an ID is passed and it's not available in cache. """ + async def convert(self, ctx: Context, argument: str) -> discord.User: match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) result = None @@ -263,6 +268,7 @@ class UserConverter(IDConverter[discord.User]): return result + class PartialMessageConverter(Converter[discord.PartialMessage]): """Converts to a :class:`discord.PartialMessage`. @@ -274,6 +280,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): 2. By message ID (The message is assumed to be in the context channel.) 3. By message URL """ + @staticmethod def _get_id_matches(argument): id_regex = re.compile(r'(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$') @@ -285,8 +292,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): match = id_regex.match(argument) or link_regex.match(argument) if not match: raise MessageNotFound(argument) - channel_id = match.group("channel_id") - return int(match.group("message_id")), int(channel_id) if channel_id else None + channel_id = match.group('channel_id') + return int(match.group('message_id')), int(channel_id) if channel_id else None async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: message_id, channel_id = self._get_id_matches(argument) @@ -295,6 +302,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): raise ChannelNotFound(channel_id) return discord.PartialMessage(channel=channel, id=message_id) + class MessageConverter(IDConverter[discord.Message]): """Converts to a :class:`discord.Message`. @@ -309,6 +317,7 @@ class MessageConverter(IDConverter[discord.Message]): .. versionchanged:: 1.5 Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.Message: message_id, channel_id = PartialMessageConverter._get_id_matches(argument) message = ctx.bot._connection._get_message(message_id) @@ -324,6 +333,7 @@ class MessageConverter(IDConverter[discord.Message]): except discord.Forbidden: raise ChannelNotReadable(channel) + class TextChannelConverter(IDConverter[discord.TextChannel]): """Converts to a :class:`~discord.TextChannel`. @@ -339,6 +349,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): .. versionchanged:: 1.5 Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: bot = ctx.bot @@ -351,8 +362,10 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): if guild: result = discord.utils.get(guild.text_channels, name=argument) else: + def check(c): return isinstance(c, discord.TextChannel) and c.name == argument + result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -366,6 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): return result + class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """Converts to a :class:`~discord.VoiceChannel`. @@ -381,6 +395,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): .. versionchanged:: 1.5 Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: bot = ctx.bot match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) @@ -392,8 +407,10 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): if guild: result = discord.utils.get(guild.voice_channels, name=argument) else: + def check(c): return isinstance(c, discord.VoiceChannel) and c.name == argument + result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -407,6 +424,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): return result + class StageChannelConverter(IDConverter[discord.StageChannel]): """Converts to a :class:`~discord.StageChannel`. @@ -421,6 +439,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): 2. Lookup by mention. 3. Lookup by name """ + async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: bot = ctx.bot match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) @@ -432,8 +451,10 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): if guild: result = discord.utils.get(guild.stage_channels, name=argument) else: + def check(c): return isinstance(c, discord.StageChannel) and c.name == argument + result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -447,6 +468,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): return result + class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """Converts to a :class:`~discord.CategoryChannel`. @@ -462,6 +484,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): .. versionchanged:: 1.5 Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: bot = ctx.bot @@ -474,8 +497,10 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): if guild: result = discord.utils.get(guild.categories, name=argument) else: + def check(c): return isinstance(c, discord.CategoryChannel) and c.name == argument + result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -489,6 +514,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): return result + class StoreChannelConverter(IDConverter[discord.StoreChannel]): """Converts to a :class:`~discord.StoreChannel`. @@ -515,8 +541,10 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): if guild: result = discord.utils.get(guild.channels, name=argument) else: + def check(c): return isinstance(c, discord.StoreChannel) and c.name == argument + result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -530,6 +558,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): return result + class ColourConverter(Converter[discord.Colour]): """Converts to a :class:`~discord.Colour`. @@ -612,8 +641,10 @@ class ColourConverter(Converter[discord.Colour]): raise BadColourArgument(arg) return method() + ColorConverter = ColourConverter + class RoleConverter(IDConverter[discord.Role]): """Converts to a :class:`~discord.Role`. @@ -629,6 +660,7 @@ class RoleConverter(IDConverter[discord.Role]): .. versionchanged:: 1.5 Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.Role: guild = ctx.guild if not guild: @@ -644,11 +676,14 @@ class RoleConverter(IDConverter[discord.Role]): raise RoleNotFound(argument) return result + class GameConverter(Converter[discord.Game]): """Converts to :class:`~discord.Game`.""" + async def convert(self, ctx: Context, argument: str) -> discord.Game: return discord.Game(name=argument) + class InviteConverter(Converter[discord.Invite]): """Converts to a :class:`~discord.Invite`. @@ -657,6 +692,7 @@ class InviteConverter(Converter[discord.Invite]): .. versionchanged:: 1.5 Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.Invite: try: invite = await ctx.bot.fetch_invite(argument) @@ -664,6 +700,7 @@ class InviteConverter(Converter[discord.Invite]): except Exception as exc: raise BadInviteArgument() from exc + class GuildConverter(IDConverter[discord.Guild]): """Converts to a :class:`~discord.Guild`. @@ -690,6 +727,7 @@ class GuildConverter(IDConverter[discord.Guild]): raise GuildNotFound(argument) return result + class EmojiConverter(IDConverter[discord.Emoji]): """Converts to a :class:`~discord.Emoji`. @@ -705,6 +743,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): .. versionchanged:: 1.5 Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.Emoji: match = self._get_id_match(argument) or re.match(r'$', argument) result = None @@ -733,6 +772,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): return result + class PartialEmojiConverter(Converter[discord.PartialEmoji]): """Converts to a :class:`~discord.PartialEmoji`. @@ -741,6 +781,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): .. versionchanged:: 1.5 Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` """ + async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument) @@ -749,11 +790,13 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): emoji_name = match.group(2) emoji_id = int(match.group(3)) - return discord.PartialEmoji.with_state(ctx.bot._connection, animated=emoji_animated, name=emoji_name, - id=emoji_id) + return discord.PartialEmoji.with_state( + ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id + ) raise PartialEmojiConversionFailure(argument) + class clean_content(Converter[str]): """Converts the argument to mention scrubbed version of said content. @@ -773,6 +816,7 @@ class clean_content(Converter[str]): .. versionadded:: 1.7 """ + def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False): self.fix_channel_mentions = fix_channel_mentions self.use_nicknames = use_nicknames @@ -784,6 +828,7 @@ class clean_content(Converter[str]): transformations = {} if self.fix_channel_mentions and ctx.guild: + def resolve_channel(id, *, _get=ctx.guild.get_channel): ch = _get(id) return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel') @@ -791,15 +836,18 @@ class clean_content(Converter[str]): transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions) if self.use_nicknames and ctx.guild: + def resolve_member(id, *, _get=ctx.guild.get_member): m = _get(id) return '@' + m.display_name if m else '@deleted-user' + else: + def resolve_member(id, *, _get=ctx.bot.get_user): m = _get(id) return '@' + m.name if m else '@deleted-user' - + # fmt: off transformations.update( (f'<@{member_id}>', resolve_member(member_id)) for member_id in message.raw_mentions @@ -809,8 +857,10 @@ class clean_content(Converter[str]): (f'<@!{member_id}>', resolve_member(member_id)) for member_id in message.raw_mentions ) + # fmt: on if ctx.guild: + def resolve_role(_id, *, _find=ctx.guild.get_role): r = _find(_id) return '@' + r.name if r else '@deleted-role' @@ -818,7 +868,7 @@ class clean_content(Converter[str]): transformations.update( (f'<@&{role_id}>', resolve_role(role_id)) for role_id in message.raw_role_mentions - ) + ) # fmt: off def repl(obj): return transformations.get(obj.group(0), '') @@ -834,28 +884,51 @@ class clean_content(Converter[str]): # Completely ensure no mentions escape: return discord.utils.escape_mentions(result) -class _Greedy: + +class Greedy(List[T]): + r"""A special converter that greedily consumes arguments until it can't. + As a consequence of this behaviour, most input errors are silently discarded, + since it is used as an indicator of when to stop parsing. + + When a parser error is met the greedy converter stops converting, undoes the + internal string parsing routine, and continues parsing regularly. + + For example, in the following code: + + .. code-block:: python3 + + @commands.command() + async def test(ctx, numbers: Greedy[int], reason: str): + await ctx.send("numbers: {}, reason: {}".format(numbers, reason)) + + An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with + ``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\. + + For more information, check :ref:`ext_commands_special_converters`. + """ + __slots__ = ('converter',) - def __init__(self, *, converter=None): + def __init__(self, *, converter: T): self.converter = converter - def __getitem__(self, params): + def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: if not isinstance(params, tuple): params = (params,) if len(params) != 1: raise TypeError('Greedy[...] only takes a single argument') converter = params[0] - if not (callable(converter) or isinstance(converter, Converter) or hasattr(converter, '__origin__')): + origin = getattr(converter, '__origin__', None) + args = getattr(converter, '__args__', ()) + + if not (callable(converter) or isinstance(converter, Converter) or origin is not None): raise TypeError('Greedy[...] expects a type or a Converter instance.') - if converter is str or converter is type(None) or converter is _Greedy: + if converter in (str, type(None)) or origin is Greedy: raise TypeError(f'Greedy[{converter.__name__}] is invalid.') - if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__: + if origin is Union and type(None) in args: raise TypeError(f'Greedy[{converter!r}] is invalid.') - return self.__class__(converter=converter) - -Greedy = _Greedy() + return cls(converter=converter) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index d735c77ba..9bd8f4f81 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -560,7 +560,7 @@ class Command(_BaseCommand): # The greedy converter is simple -- it keeps going until it fails in which case, # it undos the view ready for the next parameter to use instead - if type(converter) is converters._Greedy: + if isinstance(converter, converters.Greedy): if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: return await self._transform_greedy_pos(ctx, param, required, converter.converter) elif param.kind == param.VAR_POSITIONAL: @@ -1042,7 +1042,7 @@ class Command(_BaseCommand): result = [] for name, param in params.items(): - greedy = isinstance(param.annotation, converters._Greedy) + greedy = isinstance(param.annotation, converters.Greedy) optional = False # postpone evaluation of if it's an optional argument # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the @@ -1059,7 +1059,6 @@ class Command(_BaseCommand): if origin is typing.Literal: name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in self._flattened_typing_literal_args(annotation)) - if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 6fa552d62..9c1e38173 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -323,27 +323,7 @@ Converters .. autoclass:: discord.ext.commands.clean_content :members: -.. data:: ext.commands.Greedy - - A special converter that greedily consumes arguments until it can't. - As a consequence of this behaviour, most input errors are silently discarded, - since it is used as an indicator of when to stop parsing. - - When a parser error is met the greedy converter stops converting, undoes the - internal string parsing routine, and continues parsing regularly. - - For example, in the following code: - - .. code-block:: python3 - - @commands.command() - async def test(ctx, numbers: Greedy[int], reason: str): - await ctx.send(f"numbers: {numbers}, reason: {reason}") - - An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with - ``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\. - - For more information, check :ref:`ext_commands_special_converters`. +.. autoclass:: ext.commands.Greedy() .. _ext_commands_api_errors: