diff --git a/discord/ui/select.py b/discord/ui/select.py index 45fca382d..23791ee0c 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload +from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict from contextvars import ContextVar import inspect import os @@ -31,11 +31,8 @@ from .item import Item, ItemCallbackType from ..enums import ChannelType, ComponentType from ..partial_emoji import PartialEmoji from ..emoji import Emoji -from ..utils import MISSING -from ..components import ( - SelectOption, - SelectMenu, -) +from ..utils import MISSING, resolve_annotation +from ..components import SelectOption, SelectMenu from ..app_commands.namespace import Namespace __all__ = ( @@ -72,11 +69,6 @@ if TYPE_CHECKING: V = TypeVar('V', bound='View', covariant=True) BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect') -SelectT = TypeVar('SelectT', bound='Select') -UserSelectT = TypeVar('UserSelectT', bound='UserSelect') -RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect') -ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect') -MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect') SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') @@ -670,89 +662,32 @@ class ChannelSelect(BaseSelect[V]): return super().values # type: ignore -@overload -def select( - *, - cls: Type[SelectT] = Select, - options: List[SelectOption] = MISSING, - channel_types: List[ChannelType] = ..., - placeholder: Optional[str] = ..., - custom_id: str = ..., - min_values: int = ..., - max_values: int = ..., - disabled: bool = ..., - row: Optional[int] = ..., -) -> SelectCallbackDecorator[V, SelectT]: - ... - - -@overload -def select( - *, - cls: Type[UserSelectT], - options: List[SelectOption] = MISSING, - channel_types: List[ChannelType] = ..., - placeholder: Optional[str] = ..., - custom_id: str = ..., - min_values: int = ..., - max_values: int = ..., - disabled: bool = ..., - row: Optional[int] = ..., -) -> SelectCallbackDecorator[V, UserSelectT]: - ... - - -@overload -def select( - *, - cls: Type[RoleSelectT], - options: List[SelectOption] = MISSING, - channel_types: List[ChannelType] = ..., - placeholder: Optional[str] = ..., - custom_id: str = ..., - min_values: int = ..., - max_values: int = ..., - disabled: bool = ..., - row: Optional[int] = ..., -) -> SelectCallbackDecorator[V, RoleSelectT]: - ... - - -@overload -def select( - *, - cls: Type[ChannelSelectT], - options: List[SelectOption] = MISSING, - channel_types: List[ChannelType] = ..., - placeholder: Optional[str] = ..., - custom_id: str = ..., - min_values: int = ..., - max_values: int = ..., - disabled: bool = ..., - row: Optional[int] = ..., -) -> SelectCallbackDecorator[V, ChannelSelectT]: - ... - - -@overload -def select( - *, - cls: Type[MentionableSelectT], - options: List[SelectOption] = MISSING, - channel_types: List[ChannelType] = MISSING, - placeholder: Optional[str] = ..., - custom_id: str = ..., - min_values: int = ..., - max_values: int = ..., - disabled: bool = ..., - row: Optional[int] = ..., -) -> SelectCallbackDecorator[V, MentionableSelectT]: - ... +def _get_select_callback_parameter(func: ItemCallbackType[V, BaseSelectT]) -> Type[BaseSelect]: + params = inspect.signature(func).parameters + if len(params) != 3: + raise TypeError( + f'select menu callback {func.__qualname__!r} requires 3 parameters, ' + 'the view instance (self), the discord.Interaction, and the select menu itself' + ) + + iterator = iter(params.values()) + parameter = next(iterator) + for parameter in iterator: + pass + + if parameter.annotation is parameter.empty: + return Select + + resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {}) + origin = getattr(resolved, '__origin__', resolved) + if origin is BaseSelect or not isinstance(origin, type) or not issubclass(origin, BaseSelect): + return Select + return origin def select( *, - cls: Type[BaseSelectT] = Select, + cls: Type[BaseSelectT] = Select if TYPE_CHECKING else MISSING, options: List[SelectOption] = MISSING, channel_types: List[ChannelType] = MISSING, placeholder: Optional[str] = None, @@ -787,7 +722,10 @@ def select( .. versionchanged:: 2.1 Added the following keyword-arguments: ``cls``, ``channel_types`` - + + .. versionchanged:: 2.2 + Now infers ``cls`` based on the callback if not supplied. + Example --------- .. code-block:: python3 @@ -802,10 +740,11 @@ def select( ------------ cls: Union[Type[:class:`discord.ui.Select`], Type[:class:`discord.ui.UserSelect`], Type[:class:`discord.ui.RoleSelect`], \ Type[:class:`discord.ui.MentionableSelect`], Type[:class:`discord.ui.ChannelSelect`]] - The class to use for the select menu. Defaults to :class:`discord.ui.Select`. You can use other - select types to display different select menus to the user. See the table above for the different - values you can get from each select type. Subclasses work as well, however the callback in the subclass will - get overridden. + The class to use for the select menu. Defaults to inferring the type from the + callback if available; otherwise defaults to :class:`discord.ui.Select`. + You can use other select types to display different select menus to the user. + See the table above for the different values you can get from each select type. + Subclasses work as well, however the callback in the subclass will get overridden. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -836,11 +775,15 @@ def select( def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: if not inspect.iscoroutinefunction(func): raise TypeError('select function must be a coroutine function') - if not issubclass(cls, BaseSelect): + if cls is MISSING: + callback_cls = _get_select_callback_parameter(func) + else: + callback_cls = getattr(cls, '__origin__', cls) + if not issubclass(callback_cls, BaseSelect): supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"]) - raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls!r}.') + raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {callback_cls!r}.') - func.__discord_ui_model_type__ = cls + func.__discord_ui_model_type__ = callback_cls func.__discord_ui_model_kwargs__ = { 'placeholder': placeholder, 'custom_id': custom_id, @@ -849,9 +792,9 @@ def select( 'max_values': max_values, 'disabled': disabled, } - if issubclass(cls, Select): + if issubclass(callback_cls, Select): func.__discord_ui_model_kwargs__['options'] = options - if issubclass(cls, ChannelSelect): + if issubclass(callback_cls, ChannelSelect): func.__discord_ui_model_kwargs__['channel_types'] = channel_types return func