diff --git a/discord/ui/select.py b/discord/ui/select.py index 23791ee0c..be489526f 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict +from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload from contextvars import ContextVar import inspect import os @@ -31,8 +31,11 @@ from .item import Item, ItemCallbackType from ..enums import ChannelType, ComponentType from ..partial_emoji import PartialEmoji from ..emoji import Emoji -from ..utils import MISSING, resolve_annotation -from ..components import SelectOption, SelectMenu +from ..utils import MISSING +from ..components import ( + SelectOption, + SelectMenu, +) from ..app_commands.namespace import Namespace __all__ = ( @@ -69,6 +72,11 @@ if TYPE_CHECKING: V = TypeVar('V', bound='View', covariant=True) BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect') +SelectT = TypeVar('SelectT', bound='Select') +UserSelectT = TypeVar('UserSelectT', bound='UserSelect') +RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect') +ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect') +MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect') SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') @@ -662,32 +670,89 @@ class ChannelSelect(BaseSelect[V]): return super().values # type: ignore -def _get_select_callback_parameter(func: ItemCallbackType[V, BaseSelectT]) -> Type[BaseSelect]: - params = inspect.signature(func).parameters - if len(params) != 3: - raise TypeError( - f'select menu callback {func.__qualname__!r} requires 3 parameters, ' - 'the view instance (self), the discord.Interaction, and the select menu itself' - ) - - iterator = iter(params.values()) - parameter = next(iterator) - for parameter in iterator: - pass - - if parameter.annotation is parameter.empty: - return Select - - resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {}) - origin = getattr(resolved, '__origin__', resolved) - if origin is BaseSelect or not isinstance(origin, type) or not issubclass(origin, BaseSelect): - return Select - return origin +@overload +def select( + *, + cls: Type[SelectT] = Select, + options: List[SelectOption] = MISSING, + channel_types: List[ChannelType] = ..., + placeholder: Optional[str] = ..., + custom_id: str = ..., + min_values: int = ..., + max_values: int = ..., + disabled: bool = ..., + row: Optional[int] = ..., +) -> SelectCallbackDecorator[V, SelectT]: + ... + + +@overload +def select( + *, + cls: Type[UserSelectT], + options: List[SelectOption] = MISSING, + channel_types: List[ChannelType] = ..., + placeholder: Optional[str] = ..., + custom_id: str = ..., + min_values: int = ..., + max_values: int = ..., + disabled: bool = ..., + row: Optional[int] = ..., +) -> SelectCallbackDecorator[V, UserSelectT]: + ... + + +@overload +def select( + *, + cls: Type[RoleSelectT], + options: List[SelectOption] = MISSING, + channel_types: List[ChannelType] = ..., + placeholder: Optional[str] = ..., + custom_id: str = ..., + min_values: int = ..., + max_values: int = ..., + disabled: bool = ..., + row: Optional[int] = ..., +) -> SelectCallbackDecorator[V, RoleSelectT]: + ... + + +@overload +def select( + *, + cls: Type[ChannelSelectT], + options: List[SelectOption] = MISSING, + channel_types: List[ChannelType] = ..., + placeholder: Optional[str] = ..., + custom_id: str = ..., + min_values: int = ..., + max_values: int = ..., + disabled: bool = ..., + row: Optional[int] = ..., +) -> SelectCallbackDecorator[V, ChannelSelectT]: + ... + + +@overload +def select( + *, + cls: Type[MentionableSelectT], + options: List[SelectOption] = MISSING, + channel_types: List[ChannelType] = MISSING, + placeholder: Optional[str] = ..., + custom_id: str = ..., + min_values: int = ..., + max_values: int = ..., + disabled: bool = ..., + row: Optional[int] = ..., +) -> SelectCallbackDecorator[V, MentionableSelectT]: + ... def select( *, - cls: Type[BaseSelectT] = Select if TYPE_CHECKING else MISSING, + cls: Type[BaseSelectT] = Select, options: List[SelectOption] = MISSING, channel_types: List[ChannelType] = MISSING, placeholder: Optional[str] = None, @@ -722,10 +787,7 @@ def select( .. versionchanged:: 2.1 Added the following keyword-arguments: ``cls``, ``channel_types`` - - .. versionchanged:: 2.2 - Now infers ``cls`` based on the callback if not supplied. - + Example --------- .. code-block:: python3 @@ -740,11 +802,10 @@ def select( ------------ cls: Union[Type[:class:`discord.ui.Select`], Type[:class:`discord.ui.UserSelect`], Type[:class:`discord.ui.RoleSelect`], \ Type[:class:`discord.ui.MentionableSelect`], Type[:class:`discord.ui.ChannelSelect`]] - The class to use for the select menu. Defaults to inferring the type from the - callback if available; otherwise defaults to :class:`discord.ui.Select`. - You can use other select types to display different select menus to the user. - See the table above for the different values you can get from each select type. - Subclasses work as well, however the callback in the subclass will get overridden. + The class to use for the select menu. Defaults to :class:`discord.ui.Select`. You can use other + select types to display different select menus to the user. See the table above for the different + values you can get from each select type. Subclasses work as well, however the callback in the subclass will + get overridden. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -775,13 +836,10 @@ def select( def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: if not inspect.iscoroutinefunction(func): raise TypeError('select function must be a coroutine function') - if cls is MISSING: - callback_cls = _get_select_callback_parameter(func) - else: - callback_cls = getattr(cls, '__origin__', cls) + callback_cls = getattr(cls, '__origin__', cls) if not issubclass(callback_cls, BaseSelect): supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"]) - raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {callback_cls!r}.') + raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls!r}.') func.__discord_ui_model_type__ = callback_cls func.__discord_ui_model_kwargs__ = {