Browse Source

[commands] Make `commands.Greedy` a `typing.Generic`

pull/6693/head
James 4 years ago
committed by GitHub
parent
commit
bcd3a00eaf
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 113
      discord/ext/commands/converter.py
  2. 5
      discord/ext/commands/core.py
  3. 22
      docs/ext/commands/api.rst

113
discord/ext/commands/converter.py

@ -26,7 +26,7 @@ from __future__ import annotations
import re import re
import inspect import inspect
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable
import discord import discord
from .errors import * from .errors import *
@ -70,11 +70,12 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get _utils_get = discord.utils.get
T = TypeVar('T', covariant=True) T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
@runtime_checkable @runtime_checkable
class Converter(Protocol[T]): class Converter(Protocol[T_co]):
"""The base class of custom converters that require the :class:`.Context` """The base class of custom converters that require the :class:`.Context`
to be passed to be useful. to be passed to be useful.
@ -85,7 +86,7 @@ class Converter(Protocol[T]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`. method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
""" """
async def convert(self, ctx: Context, argument: str) -> T: async def convert(self, ctx: Context, argument: str) -> T_co:
"""|coro| """|coro|
The method to override to do conversion logic. The method to override to do conversion logic.
@ -110,7 +111,8 @@ class Converter(Protocol[T]):
""" """
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError('Derived classes need to implement this.')
class IDConverter(Converter[T]):
class IDConverter(Converter[T_co]):
def __init__(self): def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,20})$') self._id_regex = re.compile(r'([0-9]{15,20})$')
super().__init__() super().__init__()
@ -118,6 +120,7 @@ class IDConverter(Converter[T]):
def _get_id_match(self, argument): def _get_id_match(self, argument):
return self._id_regex.match(argument) return self._id_regex.match(argument)
class MemberConverter(IDConverter[discord.Member]): class MemberConverter(IDConverter[discord.Member]):
"""Converts to a :class:`~discord.Member`. """Converts to a :class:`~discord.Member`.
@ -204,6 +207,7 @@ class MemberConverter(IDConverter[discord.Member]):
return result return result
class UserConverter(IDConverter[discord.User]): class UserConverter(IDConverter[discord.User]):
"""Converts to a :class:`~discord.User`. """Converts to a :class:`~discord.User`.
@ -223,6 +227,7 @@ class UserConverter(IDConverter[discord.User]):
This converter now lazily fetches users from the HTTP APIs if an ID is passed This converter now lazily fetches users from the HTTP APIs if an ID is passed
and it's not available in cache. and it's not available in cache.
""" """
async def convert(self, ctx: Context, argument: str) -> discord.User: async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
result = None result = None
@ -263,6 +268,7 @@ class UserConverter(IDConverter[discord.User]):
return result return result
class PartialMessageConverter(Converter[discord.PartialMessage]): class PartialMessageConverter(Converter[discord.PartialMessage]):
"""Converts to a :class:`discord.PartialMessage`. """Converts to a :class:`discord.PartialMessage`.
@ -274,6 +280,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
2. By message ID (The message is assumed to be in the context channel.) 2. By message ID (The message is assumed to be in the context channel.)
3. By message URL 3. By message URL
""" """
@staticmethod @staticmethod
def _get_id_matches(argument): def _get_id_matches(argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$') id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
@ -285,8 +292,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
match = id_regex.match(argument) or link_regex.match(argument) match = id_regex.match(argument) or link_regex.match(argument)
if not match: if not match:
raise MessageNotFound(argument) raise MessageNotFound(argument)
channel_id = match.group("channel_id") channel_id = match.group('channel_id')
return int(match.group("message_id")), int(channel_id) if channel_id else None return int(match.group('message_id')), int(channel_id) if channel_id else None
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
message_id, channel_id = self._get_id_matches(argument) message_id, channel_id = self._get_id_matches(argument)
@ -295,6 +302,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
raise ChannelNotFound(channel_id) raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id) return discord.PartialMessage(channel=channel, id=message_id)
class MessageConverter(IDConverter[discord.Message]): class MessageConverter(IDConverter[discord.Message]):
"""Converts to a :class:`discord.Message`. """Converts to a :class:`discord.Message`.
@ -309,6 +317,7 @@ class MessageConverter(IDConverter[discord.Message]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Message: async def convert(self, ctx: Context, argument: str) -> discord.Message:
message_id, channel_id = PartialMessageConverter._get_id_matches(argument) message_id, channel_id = PartialMessageConverter._get_id_matches(argument)
message = ctx.bot._connection._get_message(message_id) message = ctx.bot._connection._get_message(message_id)
@ -324,6 +333,7 @@ class MessageConverter(IDConverter[discord.Message]):
except discord.Forbidden: except discord.Forbidden:
raise ChannelNotReadable(channel) raise ChannelNotReadable(channel)
class TextChannelConverter(IDConverter[discord.TextChannel]): class TextChannelConverter(IDConverter[discord.TextChannel]):
"""Converts to a :class:`~discord.TextChannel`. """Converts to a :class:`~discord.TextChannel`.
@ -339,6 +349,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
bot = ctx.bot bot = ctx.bot
@ -351,8 +362,10 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
if guild: if guild:
result = discord.utils.get(guild.text_channels, name=argument) result = discord.utils.get(guild.text_channels, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.TextChannel) and c.name == argument return isinstance(c, discord.TextChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -366,6 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
return result return result
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""Converts to a :class:`~discord.VoiceChannel`. """Converts to a :class:`~discord.VoiceChannel`.
@ -381,6 +395,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
bot = ctx.bot bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
@ -392,8 +407,10 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
if guild: if guild:
result = discord.utils.get(guild.voice_channels, name=argument) result = discord.utils.get(guild.voice_channels, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.VoiceChannel) and c.name == argument return isinstance(c, discord.VoiceChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -407,6 +424,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
return result return result
class StageChannelConverter(IDConverter[discord.StageChannel]): class StageChannelConverter(IDConverter[discord.StageChannel]):
"""Converts to a :class:`~discord.StageChannel`. """Converts to a :class:`~discord.StageChannel`.
@ -421,6 +439,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
2. Lookup by mention. 2. Lookup by mention.
3. Lookup by name 3. Lookup by name
""" """
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
bot = ctx.bot bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
@ -432,8 +451,10 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
if guild: if guild:
result = discord.utils.get(guild.stage_channels, name=argument) result = discord.utils.get(guild.stage_channels, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.StageChannel) and c.name == argument return isinstance(c, discord.StageChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -447,6 +468,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
return result return result
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""Converts to a :class:`~discord.CategoryChannel`. """Converts to a :class:`~discord.CategoryChannel`.
@ -462,6 +484,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
bot = ctx.bot bot = ctx.bot
@ -474,8 +497,10 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
if guild: if guild:
result = discord.utils.get(guild.categories, name=argument) result = discord.utils.get(guild.categories, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.CategoryChannel) and c.name == argument return isinstance(c, discord.CategoryChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -489,6 +514,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
return result return result
class StoreChannelConverter(IDConverter[discord.StoreChannel]): class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""Converts to a :class:`~discord.StoreChannel`. """Converts to a :class:`~discord.StoreChannel`.
@ -515,8 +541,10 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
if guild: if guild:
result = discord.utils.get(guild.channels, name=argument) result = discord.utils.get(guild.channels, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.StoreChannel) and c.name == argument return isinstance(c, discord.StoreChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -530,6 +558,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
return result return result
class ColourConverter(Converter[discord.Colour]): class ColourConverter(Converter[discord.Colour]):
"""Converts to a :class:`~discord.Colour`. """Converts to a :class:`~discord.Colour`.
@ -612,8 +641,10 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(arg) raise BadColourArgument(arg)
return method() return method()
ColorConverter = ColourConverter ColorConverter = ColourConverter
class RoleConverter(IDConverter[discord.Role]): class RoleConverter(IDConverter[discord.Role]):
"""Converts to a :class:`~discord.Role`. """Converts to a :class:`~discord.Role`.
@ -629,6 +660,7 @@ class RoleConverter(IDConverter[discord.Role]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Role: async def convert(self, ctx: Context, argument: str) -> discord.Role:
guild = ctx.guild guild = ctx.guild
if not guild: if not guild:
@ -644,11 +676,14 @@ class RoleConverter(IDConverter[discord.Role]):
raise RoleNotFound(argument) raise RoleNotFound(argument)
return result return result
class GameConverter(Converter[discord.Game]): class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`.""" """Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context, argument: str) -> discord.Game: async def convert(self, ctx: Context, argument: str) -> discord.Game:
return discord.Game(name=argument) return discord.Game(name=argument)
class InviteConverter(Converter[discord.Invite]): class InviteConverter(Converter[discord.Invite]):
"""Converts to a :class:`~discord.Invite`. """Converts to a :class:`~discord.Invite`.
@ -657,6 +692,7 @@ class InviteConverter(Converter[discord.Invite]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Invite: async def convert(self, ctx: Context, argument: str) -> discord.Invite:
try: try:
invite = await ctx.bot.fetch_invite(argument) invite = await ctx.bot.fetch_invite(argument)
@ -664,6 +700,7 @@ class InviteConverter(Converter[discord.Invite]):
except Exception as exc: except Exception as exc:
raise BadInviteArgument() from exc raise BadInviteArgument() from exc
class GuildConverter(IDConverter[discord.Guild]): class GuildConverter(IDConverter[discord.Guild]):
"""Converts to a :class:`~discord.Guild`. """Converts to a :class:`~discord.Guild`.
@ -690,6 +727,7 @@ class GuildConverter(IDConverter[discord.Guild]):
raise GuildNotFound(argument) raise GuildNotFound(argument)
return result return result
class EmojiConverter(IDConverter[discord.Emoji]): class EmojiConverter(IDConverter[discord.Emoji]):
"""Converts to a :class:`~discord.Emoji`. """Converts to a :class:`~discord.Emoji`.
@ -705,6 +743,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Emoji: async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]+:([0-9]+)>$', argument) match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]+:([0-9]+)>$', argument)
result = None result = None
@ -733,6 +772,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
return result return result
class PartialEmojiConverter(Converter[discord.PartialEmoji]): class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""Converts to a :class:`~discord.PartialEmoji`. """Converts to a :class:`~discord.PartialEmoji`.
@ -741,6 +781,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
.. versionchanged:: 1.5 .. versionchanged:: 1.5
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument) match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument)
@ -749,11 +790,13 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
emoji_name = match.group(2) emoji_name = match.group(2)
emoji_id = int(match.group(3)) emoji_id = int(match.group(3))
return discord.PartialEmoji.with_state(ctx.bot._connection, animated=emoji_animated, name=emoji_name, return discord.PartialEmoji.with_state(
id=emoji_id) ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id
)
raise PartialEmojiConversionFailure(argument) raise PartialEmojiConversionFailure(argument)
class clean_content(Converter[str]): class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of """Converts the argument to mention scrubbed version of
said content. said content.
@ -773,6 +816,7 @@ class clean_content(Converter[str]):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False): def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False):
self.fix_channel_mentions = fix_channel_mentions self.fix_channel_mentions = fix_channel_mentions
self.use_nicknames = use_nicknames self.use_nicknames = use_nicknames
@ -784,6 +828,7 @@ class clean_content(Converter[str]):
transformations = {} transformations = {}
if self.fix_channel_mentions and ctx.guild: if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id, *, _get=ctx.guild.get_channel): def resolve_channel(id, *, _get=ctx.guild.get_channel):
ch = _get(id) ch = _get(id)
return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel') return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel')
@ -791,15 +836,18 @@ class clean_content(Converter[str]):
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions) transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions)
if self.use_nicknames and ctx.guild: if self.use_nicknames and ctx.guild:
def resolve_member(id, *, _get=ctx.guild.get_member): def resolve_member(id, *, _get=ctx.guild.get_member):
m = _get(id) m = _get(id)
return '@' + m.display_name if m else '@deleted-user' return '@' + m.display_name if m else '@deleted-user'
else: else:
def resolve_member(id, *, _get=ctx.bot.get_user): def resolve_member(id, *, _get=ctx.bot.get_user):
m = _get(id) m = _get(id)
return '@' + m.name if m else '@deleted-user' return '@' + m.name if m else '@deleted-user'
# fmt: off
transformations.update( transformations.update(
(f'<@{member_id}>', resolve_member(member_id)) (f'<@{member_id}>', resolve_member(member_id))
for member_id in message.raw_mentions for member_id in message.raw_mentions
@ -809,8 +857,10 @@ class clean_content(Converter[str]):
(f'<@!{member_id}>', resolve_member(member_id)) (f'<@!{member_id}>', resolve_member(member_id))
for member_id in message.raw_mentions for member_id in message.raw_mentions
) )
# fmt: on
if ctx.guild: if ctx.guild:
def resolve_role(_id, *, _find=ctx.guild.get_role): def resolve_role(_id, *, _find=ctx.guild.get_role):
r = _find(_id) r = _find(_id)
return '@' + r.name if r else '@deleted-role' return '@' + r.name if r else '@deleted-role'
@ -818,7 +868,7 @@ class clean_content(Converter[str]):
transformations.update( transformations.update(
(f'<@&{role_id}>', resolve_role(role_id)) (f'<@&{role_id}>', resolve_role(role_id))
for role_id in message.raw_role_mentions for role_id in message.raw_role_mentions
) ) # fmt: off
def repl(obj): def repl(obj):
return transformations.get(obj.group(0), '') return transformations.get(obj.group(0), '')
@ -834,28 +884,51 @@ class clean_content(Converter[str]):
# Completely ensure no mentions escape: # Completely ensure no mentions escape:
return discord.utils.escape_mentions(result) return discord.utils.escape_mentions(result)
class _Greedy:
class Greedy(List[T]):
r"""A special converter that greedily consumes arguments until it can't.
As a consequence of this behaviour, most input errors are silently discarded,
since it is used as an indicator of when to stop parsing.
When a parser error is met the greedy converter stops converting, undoes the
internal string parsing routine, and continues parsing regularly.
For example, in the following code:
.. code-block:: python3
@commands.command()
async def test(ctx, numbers: Greedy[int], reason: str):
await ctx.send("numbers: {}, reason: {}".format(numbers, reason))
An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with
``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\.
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ('converter',) __slots__ = ('converter',)
def __init__(self, *, converter=None): def __init__(self, *, converter: T):
self.converter = converter self.converter = converter
def __getitem__(self, params): def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple): if not isinstance(params, tuple):
params = (params,) params = (params,)
if len(params) != 1: if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument') raise TypeError('Greedy[...] only takes a single argument')
converter = params[0] converter = params[0]
if not (callable(converter) or isinstance(converter, Converter) or hasattr(converter, '__origin__')): origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.') raise TypeError('Greedy[...] expects a type or a Converter instance.')
if converter is str or converter is type(None) or converter is _Greedy: if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.') raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__: if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.') raise TypeError(f'Greedy[{converter!r}] is invalid.')
return self.__class__(converter=converter) return cls(converter=converter)
Greedy = _Greedy()

5
discord/ext/commands/core.py

@ -560,7 +560,7 @@ class Command(_BaseCommand):
# The greedy converter is simple -- it keeps going until it fails in which case, # The greedy converter is simple -- it keeps going until it fails in which case,
# it undos the view ready for the next parameter to use instead # it undos the view ready for the next parameter to use instead
if type(converter) is converters._Greedy: if isinstance(converter, converters.Greedy):
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
return await self._transform_greedy_pos(ctx, param, required, converter.converter) return await self._transform_greedy_pos(ctx, param, required, converter.converter)
elif param.kind == param.VAR_POSITIONAL: elif param.kind == param.VAR_POSITIONAL:
@ -1042,7 +1042,7 @@ class Command(_BaseCommand):
result = [] result = []
for name, param in params.items(): for name, param in params.items():
greedy = isinstance(param.annotation, converters._Greedy) greedy = isinstance(param.annotation, converters.Greedy)
optional = False # postpone evaluation of if it's an optional argument optional = False # postpone evaluation of if it's an optional argument
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
@ -1059,7 +1059,6 @@ class Command(_BaseCommand):
if origin is typing.Literal: if origin is typing.Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
for v in self._flattened_typing_literal_args(annotation)) for v in self._flattened_typing_literal_args(annotation))
if param.default is not param.empty: if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should # We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user. # do [name] since [name=None] or [name=] are not exactly useful for the user.

22
docs/ext/commands/api.rst

@ -323,27 +323,7 @@ Converters
.. autoclass:: discord.ext.commands.clean_content .. autoclass:: discord.ext.commands.clean_content
:members: :members:
.. data:: ext.commands.Greedy .. autoclass:: ext.commands.Greedy()
A special converter that greedily consumes arguments until it can't.
As a consequence of this behaviour, most input errors are silently discarded,
since it is used as an indicator of when to stop parsing.
When a parser error is met the greedy converter stops converting, undoes the
internal string parsing routine, and continues parsing regularly.
For example, in the following code:
.. code-block:: python3
@commands.command()
async def test(ctx, numbers: Greedy[int], reason: str):
await ctx.send(f"numbers: {numbers}, reason: {reason}")
An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with
``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\.
For more information, check :ref:`ext_commands_special_converters`.
.. _ext_commands_api_errors: .. _ext_commands_api_errors:

Loading…
Cancel
Save