From c5ecc42c72779fedabff8c0eef3d3e17c5042a34 Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Fri, 29 Sep 2023 21:55:20 +0200 Subject: [PATCH] Add support for default_values field on selects --- discord/components.py | 84 +++++++++++++- discord/enums.py | 7 ++ discord/types/components.py | 11 ++ discord/ui/select.py | 221 +++++++++++++++++++++++++++++++++++- docs/api.rst | 18 +++ docs/interactions/api.rst | 8 ++ 6 files changed, 342 insertions(+), 7 deletions(-) diff --git a/discord/components.py b/discord/components.py index 6a8345801..297f815fe 100644 --- a/discord/components.py +++ b/discord/components.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload -from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, ChannelType +from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, ChannelType, SelectDefaultValueType from .utils import get_slots, MISSING from .partial_emoji import PartialEmoji, _EmojiTag @@ -40,8 +40,10 @@ if TYPE_CHECKING: ActionRow as ActionRowPayload, TextInput as TextInputPayload, ActionRowChildComponent as ActionRowChildComponentPayload, + SelectDefaultValues as SelectDefaultValuesPayload, ) from .emoji import Emoji + from .abc import Snowflake ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput'] @@ -53,6 +55,7 @@ __all__ = ( 'SelectMenu', 'SelectOption', 'TextInput', + 'SelectDefaultValue', ) @@ -263,6 +266,7 @@ class SelectMenu(Component): 'options', 'disabled', 'channel_types', + 'default_values', ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ @@ -276,6 +280,9 @@ class SelectMenu(Component): self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] self.disabled: bool = data.get('disabled', False) self.channel_types: List[ChannelType] = [try_enum(ChannelType, t) for t in data.get('channel_types', [])] + self.default_values: List[SelectDefaultValue] = [ + SelectDefaultValue.from_dict(d) for d in data.get('default_values', []) + ] def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { @@ -291,6 +298,8 @@ class SelectMenu(Component): payload['options'] = [op.to_dict() for op in self.options] if self.channel_types: payload['channel_types'] = [t.value for t in self.channel_types] + if self.default_values: + payload["default_values"] = [v.to_dict() for v in self.default_values] return payload @@ -512,6 +521,79 @@ class TextInput(Component): return self.value +class SelectDefaultValue: + """Represents a select menu's default value. + + These can be created by users. + + .. versionadded:: 2.4 + + Parameters + ----------- + id: :class:`int` + The id of a role, user, or channel. + type: :class:`SelectDefaultValueType` + The type of value that ``id`` represents. + """ + + def __init__( + self, + *, + id: int, + type: SelectDefaultValueType, + ) -> None: + self.id: int = id + self._type: SelectDefaultValueType = type + + @property + def type(self) -> SelectDefaultValueType: + return self._type + + @type.setter + def type(self, value: SelectDefaultValueType) -> None: + if not isinstance(value, SelectDefaultValueType): + raise TypeError(f'expected SelectDefaultValueType, received {value.__class__.__name__} instead') + + self._type = value + + def __repr__(self) -> str: + return f'' + + @classmethod + def from_dict(cls, data: SelectDefaultValuesPayload) -> SelectDefaultValue: + return cls( + id=data['id'], + type=try_enum(SelectDefaultValueType, data['type']), + ) + + def to_dict(self) -> SelectDefaultValuesPayload: + return { + 'id': self.id, + 'type': self._type.value, + } + + @classmethod + def from_channel(cls, channel: Snowflake, /) -> Self: + return cls( + id=channel.id, + type=SelectDefaultValueType.channel, + ) + + @classmethod + def from_role(cls, role: Snowflake, /) -> Self: + return cls( + id=role.id, + type=SelectDefaultValueType.role, + ) + + @classmethod + def from_user(cls, user: Snowflake, /) -> Self: + return cls( + id=user.id, + type=SelectDefaultValueType.user, + ) + + @overload def _component_factory(data: ActionRowChildComponentPayload) -> Optional[ActionRowChildComponentType]: ... diff --git a/discord/enums.py b/discord/enums.py index c0a2c3f43..254f86bc7 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -69,6 +69,7 @@ __all__ = ( 'AutoModRuleActionType', 'ForumLayoutType', 'ForumOrderType', + 'SelectDefaultValueType', ) if TYPE_CHECKING: @@ -772,6 +773,12 @@ class ForumOrderType(Enum): creation_date = 1 +class SelectDefaultValueType(Enum): + user = 'user' + role = 'role' + channel = 'channel' + + def create_unknown_value(cls: Type[E], val: Any) -> E: value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below name = f'unknown_{val}' diff --git a/discord/types/components.py b/discord/types/components.py index f1790ff35..218f5cef0 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -33,6 +33,7 @@ from .channel import ChannelType ComponentType = Literal[1, 2, 3, 4] ButtonStyle = Literal[1, 2, 3, 4, 5] TextStyle = Literal[1, 2] +DefaultValueType = Literal['user', 'role', 'channel'] class ActionRow(TypedDict): @@ -66,6 +67,11 @@ class SelectComponent(TypedDict): disabled: NotRequired[bool] +class SelectDefaultValues(TypedDict): + id: int + type: DefaultValueType + + class StringSelectComponent(SelectComponent): type: Literal[3] options: NotRequired[List[SelectOption]] @@ -73,19 +79,23 @@ class StringSelectComponent(SelectComponent): class UserSelectComponent(SelectComponent): type: Literal[5] + default_values: NotRequired[List[SelectDefaultValues]] class RoleSelectComponent(SelectComponent): type: Literal[6] + default_values: NotRequired[List[SelectDefaultValues]] class MentionableSelectComponent(SelectComponent): type: Literal[7] + default_values: NotRequired[List[SelectDefaultValues]] class ChannelSelectComponent(SelectComponent): type: Literal[8] channel_types: NotRequired[List[ChannelType]] + default_values: NotRequired[List[SelectDefaultValues]] class TextInput(TypedDict): @@ -104,6 +114,7 @@ class SelectMenu(SelectComponent): type: Literal[3, 5, 6, 7, 8] options: NotRequired[List[SelectOption]] channel_types: NotRequired[List[ChannelType]] + default_values: NotRequired[List[SelectDefaultValues]] ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput] diff --git a/discord/ui/select.py b/discord/ui/select.py index e54180cac..6433b64e5 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -22,21 +22,42 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload +from typing import ( + Any, + List, + Literal, + Optional, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Callable, + Union, + Dict, + overload, + Sequence, +) from contextvars import ContextVar import inspect import os from .item import Item, ItemCallbackType -from ..enums import ChannelType, ComponentType +from ..enums import ChannelType, ComponentType, SelectDefaultValueType from ..partial_emoji import PartialEmoji from ..emoji import Emoji from ..utils import MISSING from ..components import ( SelectOption, SelectMenu, + SelectDefaultValue, ) from ..app_commands.namespace import Namespace +from ..member import Member +from ..object import Object +from ..role import Role +from ..user import User +from ..abc import GuildChannel +from ..threads import Thread __all__ = ( 'Select', @@ -54,9 +75,6 @@ if TYPE_CHECKING: from ..types.components import SelectMenu as SelectMenuPayload from ..types.interactions import SelectMessageComponentInteractionData from ..app_commands import AppCommandChannel, AppCommandThread - from ..member import Member - from ..role import Role - from ..user import User from ..interactions import Interaction ValidSelectType: TypeAlias = Literal[ @@ -69,6 +87,17 @@ if TYPE_CHECKING: PossibleValue: TypeAlias = Union[ str, User, Member, Role, AppCommandChannel, AppCommandThread, Union[Role, Member], Union[Role, User] ] + ValidDefaultValues: TypeAlias = Union[ + SelectDefaultValue, + Object, + Role, + Member, + User, + GuildChannel, + AppCommandChannel, + AppCommandThread, + Thread, + ] V = TypeVar('V', bound='View', covariant=True) BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect[Any]') @@ -82,6 +111,73 @@ SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]] selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') +def _handle_select_defaults( + defaults: Sequence[ValidDefaultValues], + component_type: Literal[ + ComponentType.user_select, + ComponentType.role_select, + ComponentType.channel_select, + ComponentType.mentionable_select, + ], +) -> List[SelectDefaultValue]: + if not defaults or defaults is MISSING: + return [] + + from ..app_commands import AppCommandChannel, AppCommandThread + + cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = { + User: SelectDefaultValueType.user, + Member: SelectDefaultValueType.user, + Role: SelectDefaultValueType.role, + GuildChannel: SelectDefaultValueType.channel, + AppCommandChannel: SelectDefaultValueType.channel, + AppCommandThread: SelectDefaultValueType.channel, + Thread: SelectDefaultValueType.channel, + } + type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = { + ComponentType.user_select: (User, Member, Object), + ComponentType.role_select: (Role, Object), + ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object), + ComponentType.mentionable_select: (User, Member, Role, Object), + } + + values: List[SelectDefaultValue] = [] + for obj in defaults: + if isinstance(obj, SelectDefaultValue): + values.append(obj) + continue + + object_type = obj.__class__ if not isinstance(obj, Object) else obj.type + + if object_type not in type_to_supported_classes[component_type]: + # TODO: split this into a util function + supported_classes = [c.__name__ for c in type_to_supported_classes[component_type]] + if len(supported_classes) > 2: + supported_classes = ', '.join(supported_classes[:-1]) + f', or {supported_classes[-1]}' + elif len(supported_classes) == 2: + supported_classes = f'{supported_classes[0]} or {supported_classes[1]}' + else: + supported_classes = supported_classes[0] + + raise TypeError(f'Expected an instance of {supported_classes} not {object_type.__name__}') + + if object_type is Object: + if component_type is ComponentType.mentionable_select: + raise ValueError( + 'Object must have a type specified for the chosen select type. Please pass one using the `type`` kwarg.' + ) + elif component_type is ComponentType.user_select: + object_type = User + elif component_type is ComponentType.role_select: + object_type = Role + elif component_type is ComponentType.channel_select: + object_type = GuildChannel + + values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type])) + + return values + + class BaseSelect(Item[V]): """The base Select model that all other Select models inherit from. @@ -128,6 +224,7 @@ class BaseSelect(Item[V]): disabled: bool = False, options: List[SelectOption] = MISSING, channel_types: List[ChannelType] = MISSING, + default_values: Sequence[SelectDefaultValue] = MISSING, ) -> None: super().__init__() self._provided_custom_id = custom_id is not MISSING @@ -144,6 +241,7 @@ class BaseSelect(Item[V]): disabled=disabled, channel_types=[] if channel_types is MISSING else channel_types, options=[] if options is MISSING else options, + default_values=[] if default_values is MISSING else default_values, ) self.row = row @@ -410,6 +508,10 @@ class UserSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the users that should be selected by default. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -418,6 +520,8 @@ class UserSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ + __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',) + def __init__( self, *, @@ -427,6 +531,7 @@ class UserSelect(BaseSelect[V]): max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -436,6 +541,7 @@ class UserSelect(BaseSelect[V]): max_values=max_values, disabled=disabled, row=row, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -456,6 +562,18 @@ class UserSelect(BaseSelect[V]): """ return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + class RoleSelect(BaseSelect[V]): """Represents a UI select menu with a list of predefined options with the current roles of the guild. @@ -479,6 +597,10 @@ class RoleSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the users that should be selected by default. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -487,6 +609,8 @@ class RoleSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ + __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',) + def __init__( self, *, @@ -496,6 +620,7 @@ class RoleSelect(BaseSelect[V]): max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -505,6 +630,7 @@ class RoleSelect(BaseSelect[V]): max_values=max_values, disabled=disabled, row=row, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -517,6 +643,18 @@ class RoleSelect(BaseSelect[V]): """List[:class:`discord.Role`]: A list of roles that have been selected by the user.""" return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + class MentionableSelect(BaseSelect[V]): """Represents a UI select menu with a list of predefined options with the current members and roles in the guild. @@ -543,6 +681,11 @@ class MentionableSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the users/roles that should be selected by default. + if :class:`.Object` is passed, then the type must be specified in the constructor. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -551,6 +694,8 @@ class MentionableSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ + __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',) + def __init__( self, *, @@ -560,6 +705,7 @@ class MentionableSelect(BaseSelect[V]): max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -569,6 +715,7 @@ class MentionableSelect(BaseSelect[V]): max_values=max_values, disabled=disabled, row=row, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -589,6 +736,18 @@ class MentionableSelect(BaseSelect[V]): """ return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + class ChannelSelect(BaseSelect[V]): """Represents a UI select menu with a list of predefined options with the current channels in the guild. @@ -614,6 +773,10 @@ class ChannelSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the channels that should be selected by default. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -622,7 +785,10 @@ class ChannelSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ - __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('channel_types',) + __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ( + 'channel_types', + 'default_values', + ) def __init__( self, @@ -634,6 +800,7 @@ class ChannelSelect(BaseSelect[V]): max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -644,6 +811,7 @@ class ChannelSelect(BaseSelect[V]): disabled=disabled, row=row, channel_types=channel_types, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -670,6 +838,18 @@ class ChannelSelect(BaseSelect[V]): """List[Union[:class:`~discord.app_commands.AppCommandChannel`, :class:`~discord.app_commands.AppCommandThread`]]: A list of channels selected by the user.""" return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + @overload def select( @@ -698,6 +878,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, UserSelectT]: ... @@ -714,6 +895,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, RoleSelectT]: ... @@ -730,6 +912,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, ChannelSelectT]: ... @@ -746,6 +929,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, MentionableSelectT]: ... @@ -761,6 +945,7 @@ def select( min_values: int = 1, max_values: int = 1, disabled: bool = False, + default_values: Sequence[ValidDefaultValues] = MISSING, row: Optional[int] = None, ) -> SelectCallbackDecorator[V, BaseSelectT]: """A decorator that attaches a select menu to a component. @@ -832,6 +1017,12 @@ def select( with :class:`ChannelSelect` instances. disabled: :class:`bool` Whether the select is disabled or not. Defaults to ``False``. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the default values for the select menu. This cannot be used with regular :class:`Select` instances. + If ``cls`` is :class:`MentionableSelect` and :class:`.Object` is passed, then the type must be specified in the constructor. + if `cls` is :class:`MentionableSelect` and :class:`.Object` is passed, then the type must be specified in the constructor. + + .. versionadded:: 2.4 """ def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: @@ -855,6 +1046,24 @@ def select( func.__discord_ui_model_kwargs__['options'] = options if issubclass(callback_cls, ChannelSelect): func.__discord_ui_model_kwargs__['channel_types'] = channel_types + if not issubclass(callback_cls, Select): + cls_to_type: Dict[ + Type[BaseSelect], + Literal[ + ComponentType.user_select, + ComponentType.channel_select, + ComponentType.role_select, + ComponentType.mentionable_select, + ], + ] = { + UserSelect: ComponentType.user_select, + RoleSelect: ComponentType.role_select, + MentionableSelect: ComponentType.mentionable_select, + ChannelSelect: ComponentType.channel_select, + } + func.__discord_ui_model_kwargs__['default_values'] = ( + MISSING if default_values is MISSING else _handle_select_defaults(default_values, cls_to_type[callback_cls]) + ) return func diff --git a/docs/api.rst b/docs/api.rst index 4db962917..89b05a8c3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3392,6 +3392,24 @@ of :class:`enum.Enum`. Sort forum posts by creation time (from most recent to oldest). +.. class:: SelectDefaultValueType + + Represents the default value of a select menu. + + .. versionadded:: 2.4 + + .. attribute:: user + + The underlying type of the ID is a user. + + .. attribute:: role + + The underlying type of the ID is a role. + + .. attribute:: channel + + The underlying type of the ID is a channel or thread. + .. _discord-api-audit-logs: diff --git a/docs/interactions/api.rst b/docs/interactions/api.rst index 8e930c6ef..95c1922d1 100644 --- a/docs/interactions/api.rst +++ b/docs/interactions/api.rst @@ -166,6 +166,14 @@ SelectOption .. autoclass:: SelectOption :members: +SelectDefaultValue +~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: SelectDefaultValue + +.. autoclass:: SelectDefaultValue + :members: + Choice ~~~~~~~