From 7f7a0acd2187259f474a48d281fa0a6e01ef0dec Mon Sep 17 00:00:00 2001 From: Rapptz Date: Fri, 29 Sep 2023 19:04:03 -0400 Subject: [PATCH] Fix GuildChannel subclasses not working with default select values This also fixes it so ClientUser is respected as well --- discord/ui/select.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/discord/ui/select.py b/discord/ui/select.py index 6433b64e5..c6f125604 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -55,7 +55,9 @@ from ..app_commands.namespace import Namespace from ..member import Member from ..object import Object from ..role import Role -from ..user import User +from ..user import User, ClientUser + +# from ..channel import TextChannel, VoiceChannel, StageChannel, CategoryChannel, ForumChannel from ..abc import GuildChannel from ..threads import Thread @@ -69,7 +71,7 @@ __all__ = ( ) if TYPE_CHECKING: - from typing_extensions import TypeAlias, Self + from typing_extensions import TypeAlias, Self, TypeGuard from .view import View from ..types.components import SelectMenu as SelectMenuPayload @@ -92,6 +94,7 @@ if TYPE_CHECKING: Object, Role, Member, + ClientUser, User, GuildChannel, AppCommandChannel, @@ -107,18 +110,26 @@ RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect[Any]') ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]') MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]') SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] +DefaultSelectComponentTypes = Literal[ + ComponentType.user_select, + ComponentType.role_select, + ComponentType.channel_select, + ComponentType.mentionable_select, +] selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') +def _is_valid_object_type( + obj: Any, + component_type: DefaultSelectComponentTypes, + type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]], +) -> TypeGuard[Type[ValidDefaultValues]]: + return issubclass(obj, type_to_supported_classes[component_type]) + + def _handle_select_defaults( - defaults: Sequence[ValidDefaultValues], - component_type: Literal[ - ComponentType.user_select, - ComponentType.role_select, - ComponentType.channel_select, - ComponentType.mentionable_select, - ], + defaults: Sequence[ValidDefaultValues], component_type: DefaultSelectComponentTypes ) -> List[SelectDefaultValue]: if not defaults or defaults is MISSING: return [] @@ -128,6 +139,7 @@ def _handle_select_defaults( cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = { User: SelectDefaultValueType.user, Member: SelectDefaultValueType.user, + ClientUser: SelectDefaultValueType.user, Role: SelectDefaultValueType.role, GuildChannel: SelectDefaultValueType.channel, AppCommandChannel: SelectDefaultValueType.channel, @@ -135,10 +147,10 @@ def _handle_select_defaults( Thread: SelectDefaultValueType.channel, } type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = { - ComponentType.user_select: (User, Member, Object), + ComponentType.user_select: (User, ClientUser, Member, Object), ComponentType.role_select: (Role, Object), ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object), - ComponentType.mentionable_select: (User, Member, Role, Object), + ComponentType.mentionable_select: (User, ClientUser, Member, Role, Object), } values: List[SelectDefaultValue] = [] @@ -149,7 +161,7 @@ def _handle_select_defaults( object_type = obj.__class__ if not isinstance(obj, Object) else obj.type - if object_type not in type_to_supported_classes[component_type]: + if not _is_valid_object_type(object_type, component_type, type_to_supported_classes): # TODO: split this into a util function supported_classes = [c.__name__ for c in type_to_supported_classes[component_type]] if len(supported_classes) > 2: @@ -173,6 +185,9 @@ def _handle_select_defaults( elif component_type is ComponentType.channel_select: object_type = GuildChannel + if issubclass(object_type, GuildChannel): + object_type = GuildChannel + values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type])) return values