Browse Source

[commands] Minimise code duplication in channel converters

pull/6695/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
353737239a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 140
      discord/ext/commands/converter.py

140
discord/ext/commands/converter.py

@ -26,7 +26,7 @@ from __future__ import annotations
import re
import inspect
from typing import TYPE_CHECKING, List, Protocol, TypeVar, Tuple, Union, runtime_checkable
from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable
import discord
from .errors import *
@ -72,6 +72,7 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
@runtime_checkable
@ -112,13 +113,13 @@ class Converter(Protocol[T_co]):
raise NotImplementedError('Derived classes need to implement this.')
class IDConverter(Converter[T_co]):
def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,20})$')
super().__init__()
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
def _get_id_match(self, argument):
return self._id_regex.match(argument)
class IDConverter(Converter[T_co]):
@staticmethod
def _get_id_match(argument):
return _ID_REGEX.match(argument)
class MemberConverter(IDConverter[discord.Member]):
@ -351,20 +352,24 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return self._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.text_channels, name=argument)
result: Optional[CT] = discord.utils.get(iterable, name=argument)
else:
def check(c):
return isinstance(c, discord.TextChannel) and c.name == argument
return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
@ -374,7 +379,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.TextChannel):
if not isinstance(result, type):
raise ChannelNotFound(argument)
return result
@ -397,32 +402,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.voice_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.VoiceChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.VoiceChannel):
raise ChannelNotFound(argument)
return result
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@ -441,32 +421,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.stage_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.StageChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.StageChannel):
raise ChannelNotFound(argument)
return result
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@ -486,33 +441,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.categories, name=argument)
else:
def check(c):
return isinstance(c, discord.CategoryChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.CategoryChannel):
raise ChannelNotFound(argument)
return result
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@ -531,32 +460,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.channels, name=argument)
else:
def check(c):
return isinstance(c, discord.StoreChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.StoreChannel):
raise ChannelNotFound(argument)
return result
return TextChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel)
class ColourConverter(Converter[discord.Colour]):
@ -865,10 +769,12 @@ class clean_content(Converter[str]):
r = _find(_id)
return '@' + r.name if r else '@deleted-role'
# fmt: off
transformations.update(
(f'<@&{role_id}>', resolve_role(role_id))
for role_id in message.raw_role_mentions
) # fmt: off
)
# fmt: on
def repl(obj):
return transformations.get(obj.group(0), '')

Loading…
Cancel
Save