diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 51a7d275b..ee073944d 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -159,7 +159,7 @@ def _transform_overloaded_flags(entry: AuditLogEntry, data: int) -> Union[int, f def _transform_forum_tags(entry: AuditLogEntry, data: List[ForumTagPayload]) -> List[ForumTag]: - return [ForumTag.from_data(state=entry._state, data=d) for d in data] + return [ForumTag.from_data(state=entry._state, data=d, channel_id=entry.id) for d in data] def _transform_default_reaction(entry: AuditLogEntry, data: DefaultReactionPayload) -> Optional[PartialEmoji]: diff --git a/discord/channel.py b/discord/channel.py index 9bbb93700..28e0e88df 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -53,7 +53,7 @@ from .mixins import Hashable from . import utils from .utils import MISSING from .asset import Asset -from .errors import ClientException +from .errors import ClientException, DiscordException from .stage_instance import StageInstance from .threads import Thread from .partial_emoji import _EmojiTag, PartialEmoji @@ -2154,23 +2154,29 @@ class ForumTag(Hashable): Note that if the emoji is a custom emoji, it will *not* have name information. """ - __slots__ = ('name', 'id', 'moderated', 'emoji') + __slots__ = ('name', 'id', 'moderated', 'emoji', '_state', '_channel_id') def __init__(self, *, name: str, emoji: Optional[EmojiInputType] = None, moderated: bool = False) -> None: + self._state = None + self._channel_id: Optional[int] = None self.name: str = name self.id: int = 0 self.moderated: bool = moderated self.emoji: Optional[PartialEmoji] = None if isinstance(emoji, _EmojiTag): self.emoji = emoji._to_partial() + if not self._state and emoji._state: + self._state = emoji._state elif isinstance(emoji, str): self.emoji = PartialEmoji.from_str(emoji) elif emoji is not None: raise TypeError(f'emoji must be a Emoji, PartialEmoji, str or None not {emoji.__class__.__name__}') @classmethod - def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> Self: + def from_data(cls, *, state: ConnectionState, data: ForumTagPayload, channel_id: int) -> Self: self = cls.__new__(cls) + self._state = state + self._channel_id = channel_id self.name = data['name'] self.id = int(data['id']) self.moderated = data.get('moderated', False) @@ -2204,6 +2210,79 @@ class ForumTag(Hashable): def __str__(self) -> str: return self.name + async def edit( + self, + *, + name: str = MISSING, + emoji: Optional[PartialEmoji] = MISSING, + moderated: bool = MISSING, + reason: Optional[str] = None, + ) -> ForumTag: + """|coro| + + Edits this forum tag. + + .. versionadded:: 2.1 + + Parameters + ---------- + name: :class:`str` + The name of the tag. Can only be up to 20 characters. + emoji: Optional[Union[:class:`str`, :class:`PartialEmoji`]] + The emoji to use for the tag. + moderated: :class:`bool` + Whether the tag can only be applied by moderators. + reason: Optional[:class:`str`] + The reason for creating this tag. Shows up on the audit log. + + Raises + ------- + DiscordException + There is no internal connection state. + Forbidden + You do not have permissions to edit this forum tag. + HTTPException + Editing the forum tag failed. + + Returns + -------- + :class:`ForumTag` + The newly edited forum tag. + """ + if not self._state or not self._channel_id: + raise DiscordException('Invalid state (no ConnectionState provided)') + result = ForumTag( + name=name or self.name, + emoji=emoji if emoji is not MISSING else self.emoji, + moderated=moderated if moderated is not MISSING else self.moderated, + ) + result._state = self._state + result._channel_id = self._channel_id + await self._state.http.edit_forum_tag(self._channel_id, self.id, **result.to_dict(), reason=reason) + + result.id = self.id + return result + + async def delete(self) -> None: + """|coro| + + Deletes this forum tag. + + .. versionadded:: 2.1 + + Raises + ------- + DiscordException + There is no internal connection state. + Forbidden + You do not have permissions to delete this forum tag. + HTTPException + Deleting the forum tag failed. + """ + if not self._state or not self._channel_id: + raise DiscordException('Invalid state (no ConnectionState provided)') + await self._state.http.delete_forum_tag(self._channel_id, self.id) + class ForumChannel(discord.abc.GuildChannel, Hashable): """Represents a Discord guild forum channel. @@ -2318,7 +2397,9 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') # This takes advantage of the fact that dicts are ordered since Python 3.7 - tags = [ForumTag.from_data(state=self._state, data=tag) for tag in data.get('available_tags', [])] + tags = [ + ForumTag.from_data(state=self._state, data=tag, channel_id=self.id) for tag in data.get('available_tags', []) + ] self.default_thread_slowmode_delay: int = data.get('default_thread_rate_limit_per_user', 0) self.default_layout: ForumLayoutType = try_enum(ForumLayoutType, data.get('default_forum_layout', 0)) self._available_tags: Dict[int, ForumTag] = {tag.id: tag for tag in tags} @@ -2619,15 +2700,12 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): :class:`ForumTag` The newly created tag. """ - - prior = list(self._available_tags.values()) result = ForumTag(name=name, emoji=emoji, moderated=moderated) - prior.append(result) - payload = await self._state.http.edit_channel( - self.id, reason=reason, available_tags=[tag.to_dict() for tag in prior] - ) + result._state = self._state + result._channel_id = self.id + payload = await self._state.http.create_forum_tag(self.id, **result.to_dict(), reason=reason) try: - result.id = int(payload['available_tags'][-1]['id']) # type: ignore + result.id = int(payload['available_tags'][-1]['id']) except (KeyError, IndexError, ValueError): pass diff --git a/discord/http.py b/discord/http.py index b03aab09e..fbd8aed2a 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1701,6 +1701,55 @@ class HTTPClient: params['limit'] = limit return self.request(route, params=params) + def create_forum_tag( + self, + channel_id: Snowflake, + *, + name: str, + emoji_id: Optional[Snowflake] = None, + emoji_name: Optional[str] = None, + moderated: bool = False, + reason: Optional[str] = None, + ) -> Response[channel.ForumChannel]: + payload: Dict[str, Any] = { + 'name': name, + } + if emoji_id: + payload['emoji_id'] = emoji_id + if emoji_name: + payload['emoji_name'] = emoji_name + if moderated: + payload['moderated'] = True + + return self.request(Route('POST', '/channels/{channel_id}/tags', channel_id=channel_id), json=payload, reason=reason) + + def edit_forum_tag( + self, + channel_id: Snowflake, + tag_id: Snowflake, + *, + name: str, + emoji_id: Optional[Snowflake] = None, + emoji_name: Optional[str] = None, + moderated: bool = False, + reason: Optional[str] = None, + ) -> Response[channel.ForumChannel]: + payload: Dict[str, Any] = { + 'name': name, + 'emoji_id': emoji_id, + 'emoji_name': emoji_name, + 'moderated': moderated, + } + + return self.request( + Route('PUT', '/channels/{channel_id}/tags/{tag_id}', channel_id=channel_id, tag_id=tag_id), + json=payload, + reason=reason, + ) + + def delete_forum_tag(self, channel_id: Snowflake, tag_id: Snowflake) -> Response[channel.ForumChannel]: + return self.request(Route('DELETE', '/channels/{channel_id}/tags/{tag_id}', channel_id=channel_id, tag_id=tag_id)) + # Webhook management def create_webhook(