Browse Source

Fix GuildChannel subclasses not working with default select values

This also fixes it so ClientUser is respected as well
pull/9586/head
Rapptz 2 years ago
parent
commit
7f7a0acd21
  1. 39
      discord/ui/select.py

39
discord/ui/select.py

@ -55,7 +55,9 @@ from ..app_commands.namespace import Namespace
from ..member import Member from ..member import Member
from ..object import Object from ..object import Object
from ..role import Role from ..role import Role
from ..user import User from ..user import User, ClientUser
# from ..channel import TextChannel, VoiceChannel, StageChannel, CategoryChannel, ForumChannel
from ..abc import GuildChannel from ..abc import GuildChannel
from ..threads import Thread from ..threads import Thread
@ -69,7 +71,7 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import TypeAlias, Self from typing_extensions import TypeAlias, Self, TypeGuard
from .view import View from .view import View
from ..types.components import SelectMenu as SelectMenuPayload from ..types.components import SelectMenu as SelectMenuPayload
@ -92,6 +94,7 @@ if TYPE_CHECKING:
Object, Object,
Role, Role,
Member, Member,
ClientUser,
User, User,
GuildChannel, GuildChannel,
AppCommandChannel, AppCommandChannel,
@ -107,18 +110,26 @@ RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect[Any]')
ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]') ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]')
MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]') MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]')
SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT]
DefaultSelectComponentTypes = Literal[
ComponentType.user_select,
ComponentType.role_select,
ComponentType.channel_select,
ComponentType.mentionable_select,
]
selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values')
def _is_valid_object_type(
obj: Any,
component_type: DefaultSelectComponentTypes,
type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]],
) -> TypeGuard[Type[ValidDefaultValues]]:
return issubclass(obj, type_to_supported_classes[component_type])
def _handle_select_defaults( def _handle_select_defaults(
defaults: Sequence[ValidDefaultValues], defaults: Sequence[ValidDefaultValues], component_type: DefaultSelectComponentTypes
component_type: Literal[
ComponentType.user_select,
ComponentType.role_select,
ComponentType.channel_select,
ComponentType.mentionable_select,
],
) -> List[SelectDefaultValue]: ) -> List[SelectDefaultValue]:
if not defaults or defaults is MISSING: if not defaults or defaults is MISSING:
return [] return []
@ -128,6 +139,7 @@ def _handle_select_defaults(
cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = { cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = {
User: SelectDefaultValueType.user, User: SelectDefaultValueType.user,
Member: SelectDefaultValueType.user, Member: SelectDefaultValueType.user,
ClientUser: SelectDefaultValueType.user,
Role: SelectDefaultValueType.role, Role: SelectDefaultValueType.role,
GuildChannel: SelectDefaultValueType.channel, GuildChannel: SelectDefaultValueType.channel,
AppCommandChannel: SelectDefaultValueType.channel, AppCommandChannel: SelectDefaultValueType.channel,
@ -135,10 +147,10 @@ def _handle_select_defaults(
Thread: SelectDefaultValueType.channel, Thread: SelectDefaultValueType.channel,
} }
type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = { type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = {
ComponentType.user_select: (User, Member, Object), ComponentType.user_select: (User, ClientUser, Member, Object),
ComponentType.role_select: (Role, Object), ComponentType.role_select: (Role, Object),
ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object), ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object),
ComponentType.mentionable_select: (User, Member, Role, Object), ComponentType.mentionable_select: (User, ClientUser, Member, Role, Object),
} }
values: List[SelectDefaultValue] = [] values: List[SelectDefaultValue] = []
@ -149,7 +161,7 @@ def _handle_select_defaults(
object_type = obj.__class__ if not isinstance(obj, Object) else obj.type object_type = obj.__class__ if not isinstance(obj, Object) else obj.type
if object_type not in type_to_supported_classes[component_type]: if not _is_valid_object_type(object_type, component_type, type_to_supported_classes):
# TODO: split this into a util function # TODO: split this into a util function
supported_classes = [c.__name__ for c in type_to_supported_classes[component_type]] supported_classes = [c.__name__ for c in type_to_supported_classes[component_type]]
if len(supported_classes) > 2: if len(supported_classes) > 2:
@ -173,6 +185,9 @@ def _handle_select_defaults(
elif component_type is ComponentType.channel_select: elif component_type is ComponentType.channel_select:
object_type = GuildChannel object_type = GuildChannel
if issubclass(object_type, GuildChannel):
object_type = GuildChannel
values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type])) values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type]))
return values return values

Loading…
Cancel
Save