From 353737239a61444c0b7ca89144602bbbcf4906b6 Mon Sep 17 00:00:00 2001 From: Nadir Chowdhury Date: Sat, 10 Apr 2021 19:01:26 +0100 Subject: [PATCH] [commands] Minimise code duplication in channel converters --- discord/ext/commands/converter.py | 140 +++++------------------------- 1 file changed, 23 insertions(+), 117 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 730113b14..4033751fc 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -26,7 +26,7 @@ from __future__ import annotations import re import inspect -from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable +from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable import discord from .errors import * @@ -72,6 +72,7 @@ def _get_from_guilds(bot, getter, argument): _utils_get = discord.utils.get T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) +CT = TypeVar('CT', bound=discord.abc.GuildChannel) @runtime_checkable @@ -112,13 +113,13 @@ class Converter(Protocol[T_co]): raise NotImplementedError('Derived classes need to implement this.') -class IDConverter(Converter[T_co]): - def __init__(self): - self._id_regex = re.compile(r'([0-9]{15,20})$') - super().__init__() +_ID_REGEX = re.compile(r'([0-9]{15,20})$') + - def _get_id_match(self, argument): - return self._id_regex.match(argument) +class IDConverter(Converter[T_co]): + @staticmethod + def _get_id_match(argument): + return _ID_REGEX.match(argument) class MemberConverter(IDConverter[discord.Member]): @@ -351,20 +352,24 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: + return self._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel) + + @staticmethod + def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT: bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) + match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) result = None guild = ctx.guild if match is None: # not a mention if guild: - result = discord.utils.get(guild.text_channels, name=argument) + result: Optional[CT] = discord.utils.get(iterable, name=argument) else: def check(c): - return isinstance(c, discord.TextChannel) and c.name == argument + return isinstance(c, type) and c.name == argument result = discord.utils.find(check, bot.get_all_channels()) else: @@ -374,7 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): else: result = _get_from_guilds(bot, 'get_channel', channel_id) - if not isinstance(result, discord.TextChannel): + if not isinstance(result, type): raise ChannelNotFound(argument) return result @@ -397,32 +402,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) - result = None - guild = ctx.guild - - if match is None: - # not a mention - if guild: - result = discord.utils.get(guild.voice_channels, name=argument) - else: - - def check(c): - return isinstance(c, discord.VoiceChannel) and c.name == argument - - result = discord.utils.find(check, bot.get_all_channels()) - else: - channel_id = int(match.group(1)) - if guild: - result = guild.get_channel(channel_id) - else: - result = _get_from_guilds(bot, 'get_channel', channel_id) - - if not isinstance(result, discord.VoiceChannel): - raise ChannelNotFound(argument) - - return result + return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -441,32 +421,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) - result = None - guild = ctx.guild - - if match is None: - # not a mention - if guild: - result = discord.utils.get(guild.stage_channels, name=argument) - else: - - def check(c): - return isinstance(c, discord.StageChannel) and c.name == argument - - result = discord.utils.find(check, bot.get_all_channels()) - else: - channel_id = int(match.group(1)) - if guild: - result = guild.get_channel(channel_id) - else: - result = _get_from_guilds(bot, 'get_channel', channel_id) - - if not isinstance(result, discord.StageChannel): - raise ChannelNotFound(argument) - - return result + return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -486,33 +441,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - bot = ctx.bot - - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) - result = None - guild = ctx.guild - - if match is None: - # not a mention - if guild: - result = discord.utils.get(guild.categories, name=argument) - else: - - def check(c): - return isinstance(c, discord.CategoryChannel) and c.name == argument - - result = discord.utils.find(check, bot.get_all_channels()) - else: - channel_id = int(match.group(1)) - if guild: - result = guild.get_channel(channel_id) - else: - result = _get_from_guilds(bot, 'get_channel', channel_id) - - if not isinstance(result, discord.CategoryChannel): - raise ChannelNotFound(argument) - - return result + return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel) class StoreChannelConverter(IDConverter[discord.StoreChannel]): @@ -531,32 +460,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: - bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) - result = None - guild = ctx.guild - - if match is None: - # not a mention - if guild: - result = discord.utils.get(guild.channels, name=argument) - else: - - def check(c): - return isinstance(c, discord.StoreChannel) and c.name == argument - - result = discord.utils.find(check, bot.get_all_channels()) - else: - channel_id = int(match.group(1)) - if guild: - result = guild.get_channel(channel_id) - else: - result = _get_from_guilds(bot, 'get_channel', channel_id) - - if not isinstance(result, discord.StoreChannel): - raise ChannelNotFound(argument) - - return result + return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel) class ColourConverter(Converter[discord.Colour]): @@ -865,10 +769,12 @@ class clean_content(Converter[str]): r = _find(_id) return '@' + r.name if r else '@deleted-role' + # fmt: off transformations.update( (f'<@&{role_id}>', resolve_role(role_id)) for role_id in message.raw_role_mentions - ) # fmt: off + ) + # fmt: on def repl(obj): return transformations.get(obj.group(0), '')