|
|
@ -26,7 +26,7 @@ from __future__ import annotations |
|
|
|
|
|
|
|
import re |
|
|
|
import inspect |
|
|
|
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable |
|
|
|
from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable |
|
|
|
|
|
|
|
import discord |
|
|
|
from .errors import * |
|
|
@ -70,11 +70,12 @@ def _get_from_guilds(bot, getter, argument): |
|
|
|
|
|
|
|
|
|
|
|
_utils_get = discord.utils.get |
|
|
|
T = TypeVar('T', covariant=True) |
|
|
|
T = TypeVar('T') |
|
|
|
T_co = TypeVar('T_co', covariant=True) |
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable |
|
|
|
class Converter(Protocol[T]): |
|
|
|
class Converter(Protocol[T_co]): |
|
|
|
"""The base class of custom converters that require the :class:`.Context` |
|
|
|
to be passed to be useful. |
|
|
|
|
|
|
@ -85,7 +86,7 @@ class Converter(Protocol[T]): |
|
|
|
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`. |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> T: |
|
|
|
async def convert(self, ctx: Context, argument: str) -> T_co: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
The method to override to do conversion logic. |
|
|
@ -110,7 +111,8 @@ class Converter(Protocol[T]): |
|
|
|
""" |
|
|
|
raise NotImplementedError('Derived classes need to implement this.') |
|
|
|
|
|
|
|
class IDConverter(Converter[T]): |
|
|
|
|
|
|
|
class IDConverter(Converter[T_co]): |
|
|
|
def __init__(self): |
|
|
|
self._id_regex = re.compile(r'([0-9]{15,20})$') |
|
|
|
super().__init__() |
|
|
@ -118,6 +120,7 @@ class IDConverter(Converter[T]): |
|
|
|
def _get_id_match(self, argument): |
|
|
|
return self._id_regex.match(argument) |
|
|
|
|
|
|
|
|
|
|
|
class MemberConverter(IDConverter[discord.Member]): |
|
|
|
"""Converts to a :class:`~discord.Member`. |
|
|
|
|
|
|
@ -204,6 +207,7 @@ class MemberConverter(IDConverter[discord.Member]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class UserConverter(IDConverter[discord.User]): |
|
|
|
"""Converts to a :class:`~discord.User`. |
|
|
|
|
|
|
@ -223,6 +227,7 @@ class UserConverter(IDConverter[discord.User]): |
|
|
|
This converter now lazily fetches users from the HTTP APIs if an ID is passed |
|
|
|
and it's not available in cache. |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.User: |
|
|
|
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
@ -263,6 +268,7 @@ class UserConverter(IDConverter[discord.User]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class PartialMessageConverter(Converter[discord.PartialMessage]): |
|
|
|
"""Converts to a :class:`discord.PartialMessage`. |
|
|
|
|
|
|
@ -274,6 +280,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): |
|
|
|
2. By message ID (The message is assumed to be in the context channel.) |
|
|
|
3. By message URL |
|
|
|
""" |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _get_id_matches(argument): |
|
|
|
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$') |
|
|
@ -285,8 +292,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): |
|
|
|
match = id_regex.match(argument) or link_regex.match(argument) |
|
|
|
if not match: |
|
|
|
raise MessageNotFound(argument) |
|
|
|
channel_id = match.group("channel_id") |
|
|
|
return int(match.group("message_id")), int(channel_id) if channel_id else None |
|
|
|
channel_id = match.group('channel_id') |
|
|
|
return int(match.group('message_id')), int(channel_id) if channel_id else None |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: |
|
|
|
message_id, channel_id = self._get_id_matches(argument) |
|
|
@ -295,6 +302,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): |
|
|
|
raise ChannelNotFound(channel_id) |
|
|
|
return discord.PartialMessage(channel=channel, id=message_id) |
|
|
|
|
|
|
|
|
|
|
|
class MessageConverter(IDConverter[discord.Message]): |
|
|
|
"""Converts to a :class:`discord.Message`. |
|
|
|
|
|
|
@ -309,6 +317,7 @@ class MessageConverter(IDConverter[discord.Message]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.Message: |
|
|
|
message_id, channel_id = PartialMessageConverter._get_id_matches(argument) |
|
|
|
message = ctx.bot._connection._get_message(message_id) |
|
|
@ -324,6 +333,7 @@ class MessageConverter(IDConverter[discord.Message]): |
|
|
|
except discord.Forbidden: |
|
|
|
raise ChannelNotReadable(channel) |
|
|
|
|
|
|
|
|
|
|
|
class TextChannelConverter(IDConverter[discord.TextChannel]): |
|
|
|
"""Converts to a :class:`~discord.TextChannel`. |
|
|
|
|
|
|
@ -339,6 +349,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: |
|
|
|
bot = ctx.bot |
|
|
|
|
|
|
@ -351,8 +362,10 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.text_channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.TextChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
@ -366,6 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): |
|
|
|
"""Converts to a :class:`~discord.VoiceChannel`. |
|
|
|
|
|
|
@ -381,6 +395,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: |
|
|
|
bot = ctx.bot |
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
@ -392,8 +407,10 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.voice_channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.VoiceChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
@ -407,6 +424,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class StageChannelConverter(IDConverter[discord.StageChannel]): |
|
|
|
"""Converts to a :class:`~discord.StageChannel`. |
|
|
|
|
|
|
@ -421,6 +439,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): |
|
|
|
2. Lookup by mention. |
|
|
|
3. Lookup by name |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: |
|
|
|
bot = ctx.bot |
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
@ -432,8 +451,10 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.stage_channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.StageChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
@ -447,6 +468,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): |
|
|
|
"""Converts to a :class:`~discord.CategoryChannel`. |
|
|
|
|
|
|
@ -462,6 +484,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: |
|
|
|
bot = ctx.bot |
|
|
|
|
|
|
@ -474,8 +497,10 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.categories, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.CategoryChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
@ -489,6 +514,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class StoreChannelConverter(IDConverter[discord.StoreChannel]): |
|
|
|
"""Converts to a :class:`~discord.StoreChannel`. |
|
|
|
|
|
|
@ -515,8 +541,10 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.StoreChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
@ -530,6 +558,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class ColourConverter(Converter[discord.Colour]): |
|
|
|
"""Converts to a :class:`~discord.Colour`. |
|
|
|
|
|
|
@ -612,8 +641,10 @@ class ColourConverter(Converter[discord.Colour]): |
|
|
|
raise BadColourArgument(arg) |
|
|
|
return method() |
|
|
|
|
|
|
|
|
|
|
|
ColorConverter = ColourConverter |
|
|
|
|
|
|
|
|
|
|
|
class RoleConverter(IDConverter[discord.Role]): |
|
|
|
"""Converts to a :class:`~discord.Role`. |
|
|
|
|
|
|
@ -629,6 +660,7 @@ class RoleConverter(IDConverter[discord.Role]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.Role: |
|
|
|
guild = ctx.guild |
|
|
|
if not guild: |
|
|
@ -644,11 +676,14 @@ class RoleConverter(IDConverter[discord.Role]): |
|
|
|
raise RoleNotFound(argument) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class GameConverter(Converter[discord.Game]): |
|
|
|
"""Converts to :class:`~discord.Game`.""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.Game: |
|
|
|
return discord.Game(name=argument) |
|
|
|
|
|
|
|
|
|
|
|
class InviteConverter(Converter[discord.Invite]): |
|
|
|
"""Converts to a :class:`~discord.Invite`. |
|
|
|
|
|
|
@ -657,6 +692,7 @@ class InviteConverter(Converter[discord.Invite]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.Invite: |
|
|
|
try: |
|
|
|
invite = await ctx.bot.fetch_invite(argument) |
|
|
@ -664,6 +700,7 @@ class InviteConverter(Converter[discord.Invite]): |
|
|
|
except Exception as exc: |
|
|
|
raise BadInviteArgument() from exc |
|
|
|
|
|
|
|
|
|
|
|
class GuildConverter(IDConverter[discord.Guild]): |
|
|
|
"""Converts to a :class:`~discord.Guild`. |
|
|
|
|
|
|
@ -690,6 +727,7 @@ class GuildConverter(IDConverter[discord.Guild]): |
|
|
|
raise GuildNotFound(argument) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class EmojiConverter(IDConverter[discord.Emoji]): |
|
|
|
"""Converts to a :class:`~discord.Emoji`. |
|
|
|
|
|
|
@ -705,6 +743,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.Emoji: |
|
|
|
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]+:([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
@ -733,6 +772,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
class PartialEmojiConverter(Converter[discord.PartialEmoji]): |
|
|
|
"""Converts to a :class:`~discord.PartialEmoji`. |
|
|
|
|
|
|
@ -741,6 +781,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): |
|
|
|
.. versionchanged:: 1.5 |
|
|
|
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: |
|
|
|
match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument) |
|
|
|
|
|
|
@ -749,11 +790,13 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): |
|
|
|
emoji_name = match.group(2) |
|
|
|
emoji_id = int(match.group(3)) |
|
|
|
|
|
|
|
return discord.PartialEmoji.with_state(ctx.bot._connection, animated=emoji_animated, name=emoji_name, |
|
|
|
id=emoji_id) |
|
|
|
return discord.PartialEmoji.with_state( |
|
|
|
ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id |
|
|
|
) |
|
|
|
|
|
|
|
raise PartialEmojiConversionFailure(argument) |
|
|
|
|
|
|
|
|
|
|
|
class clean_content(Converter[str]): |
|
|
|
"""Converts the argument to mention scrubbed version of |
|
|
|
said content. |
|
|
@ -773,6 +816,7 @@ class clean_content(Converter[str]): |
|
|
|
|
|
|
|
.. versionadded:: 1.7 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False): |
|
|
|
self.fix_channel_mentions = fix_channel_mentions |
|
|
|
self.use_nicknames = use_nicknames |
|
|
@ -784,6 +828,7 @@ class clean_content(Converter[str]): |
|
|
|
transformations = {} |
|
|
|
|
|
|
|
if self.fix_channel_mentions and ctx.guild: |
|
|
|
|
|
|
|
def resolve_channel(id, *, _get=ctx.guild.get_channel): |
|
|
|
ch = _get(id) |
|
|
|
return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel') |
|
|
@ -791,15 +836,18 @@ class clean_content(Converter[str]): |
|
|
|
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions) |
|
|
|
|
|
|
|
if self.use_nicknames and ctx.guild: |
|
|
|
|
|
|
|
def resolve_member(id, *, _get=ctx.guild.get_member): |
|
|
|
m = _get(id) |
|
|
|
return '@' + m.display_name if m else '@deleted-user' |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
def resolve_member(id, *, _get=ctx.bot.get_user): |
|
|
|
m = _get(id) |
|
|
|
return '@' + m.name if m else '@deleted-user' |
|
|
|
|
|
|
|
|
|
|
|
# fmt: off |
|
|
|
transformations.update( |
|
|
|
(f'<@{member_id}>', resolve_member(member_id)) |
|
|
|
for member_id in message.raw_mentions |
|
|
@ -809,8 +857,10 @@ class clean_content(Converter[str]): |
|
|
|
(f'<@!{member_id}>', resolve_member(member_id)) |
|
|
|
for member_id in message.raw_mentions |
|
|
|
) |
|
|
|
# fmt: on |
|
|
|
|
|
|
|
if ctx.guild: |
|
|
|
|
|
|
|
def resolve_role(_id, *, _find=ctx.guild.get_role): |
|
|
|
r = _find(_id) |
|
|
|
return '@' + r.name if r else '@deleted-role' |
|
|
@ -818,7 +868,7 @@ class clean_content(Converter[str]): |
|
|
|
transformations.update( |
|
|
|
(f'<@&{role_id}>', resolve_role(role_id)) |
|
|
|
for role_id in message.raw_role_mentions |
|
|
|
) |
|
|
|
) # fmt: off |
|
|
|
|
|
|
|
def repl(obj): |
|
|
|
return transformations.get(obj.group(0), '') |
|
|
@ -834,28 +884,51 @@ class clean_content(Converter[str]): |
|
|
|
# Completely ensure no mentions escape: |
|
|
|
return discord.utils.escape_mentions(result) |
|
|
|
|
|
|
|
class _Greedy: |
|
|
|
|
|
|
|
class Greedy(List[T]): |
|
|
|
r"""A special converter that greedily consumes arguments until it can't. |
|
|
|
As a consequence of this behaviour, most input errors are silently discarded, |
|
|
|
since it is used as an indicator of when to stop parsing. |
|
|
|
|
|
|
|
When a parser error is met the greedy converter stops converting, undoes the |
|
|
|
internal string parsing routine, and continues parsing regularly. |
|
|
|
|
|
|
|
For example, in the following code: |
|
|
|
|
|
|
|
.. code-block:: python3 |
|
|
|
|
|
|
|
@commands.command() |
|
|
|
async def test(ctx, numbers: Greedy[int], reason: str): |
|
|
|
await ctx.send("numbers: {}, reason: {}".format(numbers, reason)) |
|
|
|
|
|
|
|
An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with |
|
|
|
``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\. |
|
|
|
|
|
|
|
For more information, check :ref:`ext_commands_special_converters`. |
|
|
|
""" |
|
|
|
|
|
|
|
__slots__ = ('converter',) |
|
|
|
|
|
|
|
def __init__(self, *, converter=None): |
|
|
|
def __init__(self, *, converter: T): |
|
|
|
self.converter = converter |
|
|
|
|
|
|
|
def __getitem__(self, params): |
|
|
|
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: |
|
|
|
if not isinstance(params, tuple): |
|
|
|
params = (params,) |
|
|
|
if len(params) != 1: |
|
|
|
raise TypeError('Greedy[...] only takes a single argument') |
|
|
|
converter = params[0] |
|
|
|
|
|
|
|
if not (callable(converter) or isinstance(converter, Converter) or hasattr(converter, '__origin__')): |
|
|
|
origin = getattr(converter, '__origin__', None) |
|
|
|
args = getattr(converter, '__args__', ()) |
|
|
|
|
|
|
|
if not (callable(converter) or isinstance(converter, Converter) or origin is not None): |
|
|
|
raise TypeError('Greedy[...] expects a type or a Converter instance.') |
|
|
|
|
|
|
|
if converter is str or converter is type(None) or converter is _Greedy: |
|
|
|
if converter in (str, type(None)) or origin is Greedy: |
|
|
|
raise TypeError(f'Greedy[{converter.__name__}] is invalid.') |
|
|
|
|
|
|
|
if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__: |
|
|
|
if origin is Union and type(None) in args: |
|
|
|
raise TypeError(f'Greedy[{converter!r}] is invalid.') |
|
|
|
|
|
|
|
return self.__class__(converter=converter) |
|
|
|
|
|
|
|
Greedy = _Greedy() |
|
|
|
return cls(converter=converter) |
|
|
|