Browse Source

[commands] Make `commands.Greedy` a `typing.Generic`

pull/6693/head
James 4 years ago
committed by GitHub
parent
commit
bcd3a00eaf
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 113
      discord/ext/commands/converter.py
  2. 5
      discord/ext/commands/core.py
  3. 22
      docs/ext/commands/api.rst

113
discord/ext/commands/converter.py

@ -26,7 +26,7 @@ from __future__ import annotations
import re
import inspect
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable
from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable
import discord
from .errors import *
@ -70,11 +70,12 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar('T', covariant=True)
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
@runtime_checkable
class Converter(Protocol[T]):
class Converter(Protocol[T_co]):
"""The base class of custom converters that require the :class:`.Context`
to be passed to be useful.
@ -85,7 +86,7 @@ class Converter(Protocol[T]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
async def convert(self, ctx: Context, argument: str) -> T:
async def convert(self, ctx: Context, argument: str) -> T_co:
"""|coro|
The method to override to do conversion logic.
@ -110,7 +111,8 @@ class Converter(Protocol[T]):
"""
raise NotImplementedError('Derived classes need to implement this.')
class IDConverter(Converter[T]):
class IDConverter(Converter[T_co]):
def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,20})$')
super().__init__()
@ -118,6 +120,7 @@ class IDConverter(Converter[T]):
def _get_id_match(self, argument):
return self._id_regex.match(argument)
class MemberConverter(IDConverter[discord.Member]):
"""Converts to a :class:`~discord.Member`.
@ -204,6 +207,7 @@ class MemberConverter(IDConverter[discord.Member]):
return result
class UserConverter(IDConverter[discord.User]):
"""Converts to a :class:`~discord.User`.
@ -223,6 +227,7 @@ class UserConverter(IDConverter[discord.User]):
This converter now lazily fetches users from the HTTP APIs if an ID is passed
and it's not available in cache.
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
result = None
@ -263,6 +268,7 @@ class UserConverter(IDConverter[discord.User]):
return result
class PartialMessageConverter(Converter[discord.PartialMessage]):
"""Converts to a :class:`discord.PartialMessage`.
@ -274,6 +280,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
2. By message ID (The message is assumed to be in the context channel.)
3. By message URL
"""
@staticmethod
def _get_id_matches(argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
@ -285,8 +292,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
channel_id = match.group("channel_id")
return int(match.group("message_id")), int(channel_id) if channel_id else None
channel_id = match.group('channel_id')
return int(match.group('message_id')), int(channel_id) if channel_id else None
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
message_id, channel_id = self._get_id_matches(argument)
@ -295,6 +302,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id)
class MessageConverter(IDConverter[discord.Message]):
"""Converts to a :class:`discord.Message`.
@ -309,6 +317,7 @@ class MessageConverter(IDConverter[discord.Message]):
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.Message:
message_id, channel_id = PartialMessageConverter._get_id_matches(argument)
message = ctx.bot._connection._get_message(message_id)
@ -324,6 +333,7 @@ class MessageConverter(IDConverter[discord.Message]):
except discord.Forbidden:
raise ChannelNotReadable(channel)
class TextChannelConverter(IDConverter[discord.TextChannel]):
"""Converts to a :class:`~discord.TextChannel`.
@ -339,6 +349,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
bot = ctx.bot
@ -351,8 +362,10 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
if guild:
result = discord.utils.get(guild.text_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.TextChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
@ -366,6 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
return result
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""Converts to a :class:`~discord.VoiceChannel`.
@ -381,6 +395,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
@ -392,8 +407,10 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
if guild:
result = discord.utils.get(guild.voice_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.VoiceChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
@ -407,6 +424,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
return result
class StageChannelConverter(IDConverter[discord.StageChannel]):
"""Converts to a :class:`~discord.StageChannel`.
@ -421,6 +439,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
2. Lookup by mention.
3. Lookup by name
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
@ -432,8 +451,10 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
if guild:
result = discord.utils.get(guild.stage_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.StageChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
@ -447,6 +468,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
return result
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""Converts to a :class:`~discord.CategoryChannel`.
@ -462,6 +484,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
bot = ctx.bot
@ -474,8 +497,10 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
if guild:
result = discord.utils.get(guild.categories, name=argument)
else:
def check(c):
return isinstance(c, discord.CategoryChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
@ -489,6 +514,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
return result
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""Converts to a :class:`~discord.StoreChannel`.
@ -515,8 +541,10 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
if guild:
result = discord.utils.get(guild.channels, name=argument)
else:
def check(c):
return isinstance(c, discord.StoreChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
@ -530,6 +558,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
return result
class ColourConverter(Converter[discord.Colour]):
"""Converts to a :class:`~discord.Colour`.
@ -612,8 +641,10 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(arg)
return method()
ColorConverter = ColourConverter
class RoleConverter(IDConverter[discord.Role]):
"""Converts to a :class:`~discord.Role`.
@ -629,6 +660,7 @@ class RoleConverter(IDConverter[discord.Role]):
.. versionchanged:: 1.5
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.Role:
guild = ctx.guild
if not guild:
@ -644,11 +676,14 @@ class RoleConverter(IDConverter[discord.Role]):
raise RoleNotFound(argument)
return result
class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context, argument: str) -> discord.Game:
return discord.Game(name=argument)
class InviteConverter(Converter[discord.Invite]):
"""Converts to a :class:`~discord.Invite`.
@ -657,6 +692,7 @@ class InviteConverter(Converter[discord.Invite]):
.. versionchanged:: 1.5
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.Invite:
try:
invite = await ctx.bot.fetch_invite(argument)
@ -664,6 +700,7 @@ class InviteConverter(Converter[discord.Invite]):
except Exception as exc:
raise BadInviteArgument() from exc
class GuildConverter(IDConverter[discord.Guild]):
"""Converts to a :class:`~discord.Guild`.
@ -690,6 +727,7 @@ class GuildConverter(IDConverter[discord.Guild]):
raise GuildNotFound(argument)
return result
class EmojiConverter(IDConverter[discord.Emoji]):
"""Converts to a :class:`~discord.Emoji`.
@ -705,6 +743,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
.. versionchanged:: 1.5
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]+:([0-9]+)>$', argument)
result = None
@ -733,6 +772,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
return result
class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""Converts to a :class:`~discord.PartialEmoji`.
@ -741,6 +781,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
.. versionchanged:: 1.5
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument)
@ -749,11 +790,13 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
emoji_name = match.group(2)
emoji_id = int(match.group(3))
return discord.PartialEmoji.with_state(ctx.bot._connection, animated=emoji_animated, name=emoji_name,
id=emoji_id)
return discord.PartialEmoji.with_state(
ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id
)
raise PartialEmojiConversionFailure(argument)
class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of
said content.
@ -773,6 +816,7 @@ class clean_content(Converter[str]):
.. versionadded:: 1.7
"""
def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False):
self.fix_channel_mentions = fix_channel_mentions
self.use_nicknames = use_nicknames
@ -784,6 +828,7 @@ class clean_content(Converter[str]):
transformations = {}
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id, *, _get=ctx.guild.get_channel):
ch = _get(id)
return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel')
@ -791,15 +836,18 @@ class clean_content(Converter[str]):
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions)
if self.use_nicknames and ctx.guild:
def resolve_member(id, *, _get=ctx.guild.get_member):
m = _get(id)
return '@' + m.display_name if m else '@deleted-user'
else:
def resolve_member(id, *, _get=ctx.bot.get_user):
m = _get(id)
return '@' + m.name if m else '@deleted-user'
# fmt: off
transformations.update(
(f'<@{member_id}>', resolve_member(member_id))
for member_id in message.raw_mentions
@ -809,8 +857,10 @@ class clean_content(Converter[str]):
(f'<@!{member_id}>', resolve_member(member_id))
for member_id in message.raw_mentions
)
# fmt: on
if ctx.guild:
def resolve_role(_id, *, _find=ctx.guild.get_role):
r = _find(_id)
return '@' + r.name if r else '@deleted-role'
@ -818,7 +868,7 @@ class clean_content(Converter[str]):
transformations.update(
(f'<@&{role_id}>', resolve_role(role_id))
for role_id in message.raw_role_mentions
)
) # fmt: off
def repl(obj):
return transformations.get(obj.group(0), '')
@ -834,28 +884,51 @@ class clean_content(Converter[str]):
# Completely ensure no mentions escape:
return discord.utils.escape_mentions(result)
class _Greedy:
class Greedy(List[T]):
r"""A special converter that greedily consumes arguments until it can't.
As a consequence of this behaviour, most input errors are silently discarded,
since it is used as an indicator of when to stop parsing.
When a parser error is met the greedy converter stops converting, undoes the
internal string parsing routine, and continues parsing regularly.
For example, in the following code:
.. code-block:: python3
@commands.command()
async def test(ctx, numbers: Greedy[int], reason: str):
await ctx.send("numbers: {}, reason: {}".format(numbers, reason))
An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with
``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\.
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ('converter',)
def __init__(self, *, converter=None):
def __init__(self, *, converter: T):
self.converter = converter
def __getitem__(self, params):
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument')
converter = params[0]
if not (callable(converter) or isinstance(converter, Converter) or hasattr(converter, '__origin__')):
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.')
if converter is str or converter is type(None) or converter is _Greedy:
if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__:
if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
return self.__class__(converter=converter)
Greedy = _Greedy()
return cls(converter=converter)

5
discord/ext/commands/core.py

@ -560,7 +560,7 @@ class Command(_BaseCommand):
# The greedy converter is simple -- it keeps going until it fails in which case,
# it undos the view ready for the next parameter to use instead
if type(converter) is converters._Greedy:
if isinstance(converter, converters.Greedy):
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
elif param.kind == param.VAR_POSITIONAL:
@ -1042,7 +1042,7 @@ class Command(_BaseCommand):
result = []
for name, param in params.items():
greedy = isinstance(param.annotation, converters._Greedy)
greedy = isinstance(param.annotation, converters.Greedy)
optional = False # postpone evaluation of if it's an optional argument
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
@ -1059,7 +1059,6 @@ class Command(_BaseCommand):
if origin is typing.Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
for v in self._flattened_typing_literal_args(annotation))
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.

22
docs/ext/commands/api.rst

@ -323,27 +323,7 @@ Converters
.. autoclass:: discord.ext.commands.clean_content
:members:
.. data:: ext.commands.Greedy
A special converter that greedily consumes arguments until it can't.
As a consequence of this behaviour, most input errors are silently discarded,
since it is used as an indicator of when to stop parsing.
When a parser error is met the greedy converter stops converting, undoes the
internal string parsing routine, and continues parsing regularly.
For example, in the following code:
.. code-block:: python3
@commands.command()
async def test(ctx, numbers: Greedy[int], reason: str):
await ctx.send(f"numbers: {numbers}, reason: {reason}")
An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with
``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\.
For more information, check :ref:`ext_commands_special_converters`.
.. autoclass:: ext.commands.Greedy()
.. _ext_commands_api_errors:

Loading…
Cancel
Save