|
|
@ -26,7 +26,7 @@ from __future__ import annotations |
|
|
|
|
|
|
|
import re |
|
|
|
import inspect |
|
|
|
from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable |
|
|
|
from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable |
|
|
|
|
|
|
|
import discord |
|
|
|
from .errors import * |
|
|
@ -72,6 +72,7 @@ def _get_from_guilds(bot, getter, argument): |
|
|
|
_utils_get = discord.utils.get |
|
|
|
T = TypeVar('T') |
|
|
|
T_co = TypeVar('T_co', covariant=True) |
|
|
|
CT = TypeVar('CT', bound=discord.abc.GuildChannel) |
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable |
|
|
@ -112,13 +113,13 @@ class Converter(Protocol[T_co]): |
|
|
|
raise NotImplementedError('Derived classes need to implement this.') |
|
|
|
|
|
|
|
|
|
|
|
class IDConverter(Converter[T_co]): |
|
|
|
def __init__(self): |
|
|
|
self._id_regex = re.compile(r'([0-9]{15,20})$') |
|
|
|
super().__init__() |
|
|
|
_ID_REGEX = re.compile(r'([0-9]{15,20})$') |
|
|
|
|
|
|
|
|
|
|
|
def _get_id_match(self, argument): |
|
|
|
return self._id_regex.match(argument) |
|
|
|
class IDConverter(Converter[T_co]): |
|
|
|
@staticmethod |
|
|
|
def _get_id_match(argument): |
|
|
|
return _ID_REGEX.match(argument) |
|
|
|
|
|
|
|
|
|
|
|
class MemberConverter(IDConverter[discord.Member]): |
|
|
@ -351,20 +352,24 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: |
|
|
|
return self._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT: |
|
|
|
bot = ctx.bot |
|
|
|
|
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
|
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
|
guild = ctx.guild |
|
|
|
|
|
|
|
if match is None: |
|
|
|
# not a mention |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.text_channels, name=argument) |
|
|
|
result: Optional[CT] = discord.utils.get(iterable, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.TextChannel) and c.name == argument |
|
|
|
return isinstance(c, type) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
@ -374,7 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): |
|
|
|
else: |
|
|
|
result = _get_from_guilds(bot, 'get_channel', channel_id) |
|
|
|
|
|
|
|
if not isinstance(result, discord.TextChannel): |
|
|
|
if not isinstance(result, type): |
|
|
|
raise ChannelNotFound(argument) |
|
|
|
|
|
|
|
return result |
|
|
@ -397,32 +402,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: |
|
|
|
bot = ctx.bot |
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
|
guild = ctx.guild |
|
|
|
|
|
|
|
if match is None: |
|
|
|
# not a mention |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.voice_channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.VoiceChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
|
if guild: |
|
|
|
result = guild.get_channel(channel_id) |
|
|
|
else: |
|
|
|
result = _get_from_guilds(bot, 'get_channel', channel_id) |
|
|
|
|
|
|
|
if not isinstance(result, discord.VoiceChannel): |
|
|
|
raise ChannelNotFound(argument) |
|
|
|
|
|
|
|
return result |
|
|
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel) |
|
|
|
|
|
|
|
|
|
|
|
class StageChannelConverter(IDConverter[discord.StageChannel]): |
|
|
@ -441,32 +421,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: |
|
|
|
bot = ctx.bot |
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
|
guild = ctx.guild |
|
|
|
|
|
|
|
if match is None: |
|
|
|
# not a mention |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.stage_channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.StageChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
|
if guild: |
|
|
|
result = guild.get_channel(channel_id) |
|
|
|
else: |
|
|
|
result = _get_from_guilds(bot, 'get_channel', channel_id) |
|
|
|
|
|
|
|
if not isinstance(result, discord.StageChannel): |
|
|
|
raise ChannelNotFound(argument) |
|
|
|
|
|
|
|
return result |
|
|
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel) |
|
|
|
|
|
|
|
|
|
|
|
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): |
|
|
@ -486,33 +441,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: |
|
|
|
bot = ctx.bot |
|
|
|
|
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
|
guild = ctx.guild |
|
|
|
|
|
|
|
if match is None: |
|
|
|
# not a mention |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.categories, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.CategoryChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
|
if guild: |
|
|
|
result = guild.get_channel(channel_id) |
|
|
|
else: |
|
|
|
result = _get_from_guilds(bot, 'get_channel', channel_id) |
|
|
|
|
|
|
|
if not isinstance(result, discord.CategoryChannel): |
|
|
|
raise ChannelNotFound(argument) |
|
|
|
|
|
|
|
return result |
|
|
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel) |
|
|
|
|
|
|
|
|
|
|
|
class StoreChannelConverter(IDConverter[discord.StoreChannel]): |
|
|
@ -531,32 +460,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): |
|
|
|
""" |
|
|
|
|
|
|
|
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: |
|
|
|
bot = ctx.bot |
|
|
|
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) |
|
|
|
result = None |
|
|
|
guild = ctx.guild |
|
|
|
|
|
|
|
if match is None: |
|
|
|
# not a mention |
|
|
|
if guild: |
|
|
|
result = discord.utils.get(guild.channels, name=argument) |
|
|
|
else: |
|
|
|
|
|
|
|
def check(c): |
|
|
|
return isinstance(c, discord.StoreChannel) and c.name == argument |
|
|
|
|
|
|
|
result = discord.utils.find(check, bot.get_all_channels()) |
|
|
|
else: |
|
|
|
channel_id = int(match.group(1)) |
|
|
|
if guild: |
|
|
|
result = guild.get_channel(channel_id) |
|
|
|
else: |
|
|
|
result = _get_from_guilds(bot, 'get_channel', channel_id) |
|
|
|
|
|
|
|
if not isinstance(result, discord.StoreChannel): |
|
|
|
raise ChannelNotFound(argument) |
|
|
|
|
|
|
|
return result |
|
|
|
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel) |
|
|
|
|
|
|
|
|
|
|
|
class ColourConverter(Converter[discord.Colour]): |
|
|
@ -865,10 +769,12 @@ class clean_content(Converter[str]): |
|
|
|
r = _find(_id) |
|
|
|
return '@' + r.name if r else '@deleted-role' |
|
|
|
|
|
|
|
# fmt: off |
|
|
|
transformations.update( |
|
|
|
(f'<@&{role_id}>', resolve_role(role_id)) |
|
|
|
for role_id in message.raw_role_mentions |
|
|
|
) # fmt: off |
|
|
|
) |
|
|
|
# fmt: on |
|
|
|
|
|
|
|
def repl(obj): |
|
|
|
return transformations.get(obj.group(0), '') |
|
|
|