Browse Source

Fix GuildChannel subclasses not working with default select values

This also fixes it so ClientUser is respected as well
pull/9586/head
Rapptz 2 years ago
parent
commit
7f7a0acd21
  1. 39
      discord/ui/select.py

39
discord/ui/select.py

@ -55,7 +55,9 @@ from ..app_commands.namespace import Namespace
from ..member import Member
from ..object import Object
from ..role import Role
from ..user import User
from ..user import User, ClientUser
# from ..channel import TextChannel, VoiceChannel, StageChannel, CategoryChannel, ForumChannel
from ..abc import GuildChannel
from ..threads import Thread
@ -69,7 +71,7 @@ __all__ = (
)
if TYPE_CHECKING:
from typing_extensions import TypeAlias, Self
from typing_extensions import TypeAlias, Self, TypeGuard
from .view import View
from ..types.components import SelectMenu as SelectMenuPayload
@ -92,6 +94,7 @@ if TYPE_CHECKING:
Object,
Role,
Member,
ClientUser,
User,
GuildChannel,
AppCommandChannel,
@ -107,18 +110,26 @@ RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect[Any]')
ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]')
MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]')
SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT]
DefaultSelectComponentTypes = Literal[
ComponentType.user_select,
ComponentType.role_select,
ComponentType.channel_select,
ComponentType.mentionable_select,
]
selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values')
def _is_valid_object_type(
obj: Any,
component_type: DefaultSelectComponentTypes,
type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]],
) -> TypeGuard[Type[ValidDefaultValues]]:
return issubclass(obj, type_to_supported_classes[component_type])
def _handle_select_defaults(
defaults: Sequence[ValidDefaultValues],
component_type: Literal[
ComponentType.user_select,
ComponentType.role_select,
ComponentType.channel_select,
ComponentType.mentionable_select,
],
defaults: Sequence[ValidDefaultValues], component_type: DefaultSelectComponentTypes
) -> List[SelectDefaultValue]:
if not defaults or defaults is MISSING:
return []
@ -128,6 +139,7 @@ def _handle_select_defaults(
cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = {
User: SelectDefaultValueType.user,
Member: SelectDefaultValueType.user,
ClientUser: SelectDefaultValueType.user,
Role: SelectDefaultValueType.role,
GuildChannel: SelectDefaultValueType.channel,
AppCommandChannel: SelectDefaultValueType.channel,
@ -135,10 +147,10 @@ def _handle_select_defaults(
Thread: SelectDefaultValueType.channel,
}
type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = {
ComponentType.user_select: (User, Member, Object),
ComponentType.user_select: (User, ClientUser, Member, Object),
ComponentType.role_select: (Role, Object),
ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object),
ComponentType.mentionable_select: (User, Member, Role, Object),
ComponentType.mentionable_select: (User, ClientUser, Member, Role, Object),
}
values: List[SelectDefaultValue] = []
@ -149,7 +161,7 @@ def _handle_select_defaults(
object_type = obj.__class__ if not isinstance(obj, Object) else obj.type
if object_type not in type_to_supported_classes[component_type]:
if not _is_valid_object_type(object_type, component_type, type_to_supported_classes):
# TODO: split this into a util function
supported_classes = [c.__name__ for c in type_to_supported_classes[component_type]]
if len(supported_classes) > 2:
@ -173,6 +185,9 @@ def _handle_select_defaults(
elif component_type is ComponentType.channel_select:
object_type = GuildChannel
if issubclass(object_type, GuildChannel):
object_type = GuildChannel
values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type]))
return values

Loading…
Cancel
Save