From 53b48904358866e62c6afec1a548424f12c7e1d1 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 13 Sep 2017 09:38:05 -0400 Subject: [PATCH] Add category support. This adds: * CategoryChannel, which represents a category * Guild.by_category() which traverses the channels grouping by category * Guild.categories to get a list of categories * abc.GuildChannel.category to get the category a channel belongs to * sync_permissions keyword argument to abc.GuildChannel.edit to sync permissions with a pre-existing or new category * category keyword argument to abc.GuildChannel.edit to move a channel to a category --- discord/abc.py | 51 +++++++++++++++- discord/channel.py | 142 ++++++++++++++++++++++++++++++++++++++++----- discord/enums.py | 9 +-- discord/guild.py | 54 +++++++++++++++-- discord/http.py | 6 +- discord/state.py | 8 +-- docs/api.rst | 7 +++ docs/migrating.rst | 4 ++ 8 files changed, 249 insertions(+), 32 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 33a471717..5be6ac104 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -188,7 +188,7 @@ class GuildChannel: return self.name @asyncio.coroutine - def _move(self, position, *, reason): + def _move(self, position, parent_id=None, lock_permissions=False, *, reason): if position < 0: raise InvalidArgument('Channel position cannot be less than 0.') @@ -211,8 +211,45 @@ class GuildChannel: # add ourselves at our designated position channels.insert(position, self) - payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)] - yield from http.move_channel_position(self.guild.id, payload, reason=reason) + payload = [] + for index, c in enumerate(channels): + d = {'id': c.id, 'position': index} + if parent_id is not _undefined and c.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + yield from http.bulk_channel_update(self.guild.id, payload, reason=reason) + self.position = position + if parent_id is not _undefined: + self.category_id = int(parent_id) + + @asyncio.coroutine + def _edit(self, options, reason): + try: + parent = options.pop('category') + except KeyError: + parent_id = _undefined + else: + parent_id = parent and parent.id + + lock_permissions = options.pop('sync_permissions', False) + + try: + position = options.pop('position') + except KeyError: + if parent_id is not _undefined: + yield from self._move(self.position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) + elif lock_permissions and self.category_id is not None: + # if we're syncing permissions on a pre-existing channel category without changing it + # we need to update the permissions to point to the pre-existing category + category = self.guild.get_channel(self.category_id) + options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + else: + yield from self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) + + if options: + data = yield from self._state.http.edit_channel(self.id, reason=reason, **options) + self._update(self.guild, data) def _fill_overwrites(self, data): self._overwrites = [] @@ -322,6 +359,14 @@ class GuildChannel: ret.append((target, overwrite)) return ret + @property + def category(self): + """Optional[:class:`CategoryChannel`]: The category this channel belongs to. + + If there is no category then this is ``None``. + """ + return self.guild.get_channel(self.category_id) + def permissions_for(self, member): """Handles permission resolution for the current :class:`Member`. diff --git a/discord/channel.py b/discord/channel.py index d81ad2735..363db258f 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -35,7 +35,7 @@ import discord.abc import time import asyncio -__all__ = ('TextChannel', 'VoiceChannel', 'DMChannel', 'GroupChannel', '_channel_factory') +__all__ = ('TextChannel', 'VoiceChannel', 'DMChannel', 'CategoryChannel', 'GroupChannel', '_channel_factory') @asyncio.coroutine def _single_delete_strategy(messages): @@ -71,6 +71,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): The guild the channel belongs to. id: int The channel ID. + category_id: int + The category channel ID this channel belongs to. topic: Optional[str] The channel's topic. None if it doesn't exist. position: int @@ -79,7 +81,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ __slots__ = ( 'name', 'id', 'guild', 'topic', '_state', 'nsfw', - 'position', '_overwrites' ) + 'category_id', 'position', '_overwrites' ) def __init__(self, *, state, guild, data): self._state = state @@ -92,6 +94,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): def _update(self, guild, data): self.guild = guild self.name = data['name'] + self.category_id = utils._get_as_snowflake(data, 'parent_id') self.topic = data.get('topic') self.position = data['position'] self.nsfw = data.get('nsfw', False) @@ -140,6 +143,12 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): The new channel's position. nsfw: bool To mark the channel as NSFW or not. + sync_permissions: bool + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. reason: Optional[str] The reason for editing this channel. Shows up on the audit log. @@ -152,17 +161,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): HTTPException Editing the channel failed. """ - try: - position = options.pop('position') - except KeyError: - pass - else: - yield from self._move(position, reason=reason) - self.position = position - - if options: - data = yield from self._state.http.edit_channel(self.id, reason=reason, **options) - self._update(self.guild, data) + yield from self._edit(options, reason=reason) @asyncio.coroutine def delete_messages(self, messages): @@ -411,6 +410,8 @@ class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): The guild the channel belongs to. id: int The channel ID. + category_id: int + The category channel ID this channel belongs to. position: int The position in the channel list. This is a number that starts at 0. e.g. the top channel is position 0. @@ -421,7 +422,7 @@ class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): """ __slots__ = ('name', 'id', 'guild', 'bitrate', 'user_limit', - '_state', 'position', '_overwrites' ) + '_state', 'position', '_overwrites', 'category_id' ) def __init__(self, *, state, guild, data): self._state = state @@ -440,6 +441,7 @@ class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): def _update(self, guild, data): self.guild = guild self.name = data['name'] + self.category_id = utils._get_as_snowflake(data, 'parent_id') self.position = data['position'] self.bitrate = data.get('bitrate') self.user_limit = data.get('user_limit') @@ -473,6 +475,12 @@ class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): The new channel's user limit. position: int The new channel's position. + sync_permissions: bool + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. reason: Optional[str] The reason for editing this channel. Shows up on the audit log. @@ -484,6 +492,97 @@ class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): Editing the channel failed. """ + yield from self._edit(options, reason=reason) + +class CategoryChannel(discord.abc.GuildChannel, Hashable): + """Represents a Discord channel category. + + These are useful to group channels to logical compartments. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the category's hash. + + .. describe:: str(x) + + Returns the category's name. + + Attributes + ----------- + name: str + The category name. + guild: :class:`Guild` + The guild the category belongs to. + id: int + The category channel ID. + position: int + The position in the category list. This is a number that starts at 0. e.g. the + top category is position 0. + """ + + __slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') + + def __init__(self, *, state, guild, data): + self._state = state + self.id = int(data['id']) + self._update(guild, data) + + def __repr__(self): + return ''.format(self) + + def _update(self, guild, data): + self.guild = guild + self.name = data['name'] + self.category_id = utils._get_as_snowflake(data, 'parent_id') + self.nsfw = data.get('nsfw', False) + self.position = data['position'] + self._fill_overwrites(data) + + def is_nsfw(self): + """Checks if the category is NSFW.""" + n = self.name + return self.nsfw or n == 'nsfw' or n[:5] == 'nsfw-' + + @asyncio.coroutine + def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`Permissions.manage_channel` permission to + use this. + + Parameters + ---------- + name: str + The new category's name. + position: int + The new category's position. + nsfw: bool + To mark the category as NSFW or not. + reason: Optional[str] + The reason for editing this category. Shows up on the audit log. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of categories. + Forbidden + You do not have permissions to edit the category. + HTTPException + Editing the category failed. + """ + try: position = options.pop('position') except KeyError: @@ -496,6 +595,19 @@ class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): data = yield from self._state.http.edit_channel(self.id, reason=reason, **options) self._update(self.guild, data) + @property + def channels(self): + """List[:class:`abc.GuildChannel`]: Returns the channels that are under this category. + + These are sorted by the official Discord UI, which places voice channels below the text channels. + """ + def comparator(channel): + return (not isinstance(channel, TextChannel), channel.position) + + ret = [c for c in self.guild.channels if c.category_id == self.id] + ret.sort(key=comparator) + return ret + class DMChannel(discord.abc.Messageable, Hashable): """Represents a Discord direct message channel. @@ -810,6 +922,8 @@ def _channel_factory(channel_type): return VoiceChannel, value elif value is ChannelType.private: return DMChannel, value + elif value is ChannelType.category: + return CategoryChannel, value elif value is ChannelType.group: return GroupChannel, value else: diff --git a/discord/enums.py b/discord/enums.py index 41762bf8c..88857bd1a 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -31,10 +31,11 @@ __all__ = ['ChannelType', 'MessageType', 'VoiceRegion', 'VerificationLevel', 'AuditLogAction', 'AuditLogActionCategory', 'UserFlags', ] class ChannelType(Enum): - text = 0 - private = 1 - voice = 2 - group = 3 + text = 0 + private = 1 + voice = 2 + group = 3 + category = 4 def __str__(self): return self.name diff --git a/discord/guild.py b/discord/guild.py index 6116ff0a1..3cf59aeb9 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -250,11 +250,11 @@ class Guild(Hashable): channels = data['channels'] for c in channels: if c['type'] == ChannelType.text.value: - channel = TextChannel(guild=self, data=c, state=self._state) - self._add_channel(channel) + self._add_channel(TextChannel(guild=self, data=c, state=self._state)) elif c['type'] == ChannelType.voice.value: - channel = VoiceChannel(guild=self, data=c, state=self._state) - self._add_channel(channel) + self._add_channel(VoiceChannel(guild=self, data=c, state=self._state)) + elif c['type'] == ChannelType.category.value: + self._add_channel(CategoryChannel(guild=self, data=c, state=self._state)) @property @@ -309,6 +309,52 @@ class Guild(Hashable): r.sort(key=lambda c: c.position) return r + @property + def categories(self): + """List[:class:`CategoryChannel`]: A list of categories that belongs to this guild. + + This is sorted by the position and are in UI order from top to bottom. + """ + r = [ch for ch in self._channels.values() if isinstance(ch, CategoryChannel)] + r.sort(key=lambda c: c.position) + return r + + def by_category(self): + """Returns every :class:`CategoryChannel` and their associated channels. + + These channels and categories are sorted in the official Discord UI order. + + If the channels do not have a category, then the first element of the tuple is + ``None``. + + Returns + -------- + List[Tuple[Optional[:class:`CategoryChannel`], List[:class:`abc.GuildChannel`]]]: + The categories and their associated channels. + """ + grouped = {} + for channel in self._channels.values(): + if isinstance(channel, CategoryChannel): + continue + + try: + channels = grouped[channel.category_id] + except KeyError: + channels = grouped[channel.category_id] = [] + + channels.append(channel) + + def key(t): + k, v = t + return (k.position if k else -1, v) + + _get = self._channels.get + as_list = [(_get(k), v) for k, v in grouped.items()] + as_list.sort(key=key) + for _, channels in as_list: + channels.sort(key=lambda c: c.position) + return as_list + def get_channel(self, channel_id): """Returns a :class:`abc.GuildChannel` with the given ID. If not found, returns None.""" return self._channels.get(channel_id) diff --git a/discord/http.py b/discord/http.py index aca91bb6f..458e68d33 100644 --- a/discord/http.py +++ b/discord/http.py @@ -500,16 +500,16 @@ class HTTPClient: def edit_channel(self, channel_id, *, reason=None, **options): r = Route('PATCH', '/channels/{channel_id}', channel_id=channel_id) - valid_keys = ('name', 'topic', 'bitrate', 'nsfw', 'user_limit', 'position') + valid_keys = ('name', 'topic', 'bitrate', 'nsfw', 'user_limit', 'position', 'permission_overwrites') payload = { k: v for k, v in options.items() if k in valid_keys } return self.request(r, reason=reason, json=payload) - def move_channel_position(self, guild_id, positions, *, reason=None): + def bulk_channel_update(self, guild_id, data, *, reason=None): r = Route('PATCH', '/guilds/{guild_id}/channels', guild_id=guild_id) - return self.request(r, json=positions, reason=reason) + return self.request(r, json=data, reason=reason) def create_channel(self, guild_id, name, channe_type, permission_overwrites=None, *, reason=None): payload = { diff --git a/discord/state.py b/discord/state.py index c470d11cd..b145c7ec3 100644 --- a/discord/state.py +++ b/discord/state.py @@ -484,6 +484,10 @@ class ConnectionState: def parse_channel_create(self, data): factory, ch_type = _channel_factory(data['type']) + if factory is None: + log.warning('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) + return + channel = None if ch_type in (ChannelType.group, ChannelType.private): @@ -496,10 +500,6 @@ class ConnectionState: guild_id = utils._get_as_snowflake(data, 'guild_id') guild = self._get_guild(guild_id) if guild is not None: - if factory is None: - log.warning('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) - return - channel = factory(guild=guild, state=self, data=data) guild._add_channel(channel) self.dispatch('guild_channel_create', channel) diff --git a/docs/api.rst b/docs/api.rst index 14381d64b..5f6245ea6 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1912,6 +1912,13 @@ VoiceChannel :members: :inherited-members: +CategoryChannel +~~~~~~~~~~~~~ + +.. autoclass:: CategoryChannel() + :members: + :inherited-members: + DMChannel ~~~~~~~~~ diff --git a/docs/migrating.rst b/docs/migrating.rst index c0434626b..57d4f3a6d 100644 --- a/docs/migrating.rst +++ b/docs/migrating.rst @@ -366,11 +366,15 @@ They will be enumerated here. **Added** - :class:`Attachment` to represent a discord attachment. +- :class:`CategoryChannel` to represent a channel category. - :attr:`VoiceChannel.members` for fetching members connected to a voice channel. - :attr:`TextChannel.members` for fetching members that can see the channel. - :attr:`Role.members` for fetching members that have the role. - :attr:`Guild.text_channels` for fetching text channels only. - :attr:`Guild.voice_channels` for fetching voice channels only. +- :attr:`Guild.categories` for fetching channel categories only. +- :attr:`TextChannel.category` and :attr:`VoiceChannel.category` to get the category a channel belongs to. +- :meth:`Guild.by_category` to get channels grouped by their category. - :attr:`Guild.chunked` to check member chunking status. - :attr:`Guild.explicit_content_filter` to fetch the content filter. - :attr:`Guild.shard_id` to get a guild's Shard ID if you're sharding.