Browse Source

Infer select type from callback annotation

pull/9141/head
Zephyrkul 2 years ago
committed by GitHub
parent
commit
b671958e11
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 143
      discord/ui/select.py

143
discord/ui/select.py

@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict
from contextvars import ContextVar from contextvars import ContextVar
import inspect import inspect
import os import os
@ -31,11 +31,8 @@ from .item import Item, ItemCallbackType
from ..enums import ChannelType, ComponentType from ..enums import ChannelType, ComponentType
from ..partial_emoji import PartialEmoji from ..partial_emoji import PartialEmoji
from ..emoji import Emoji from ..emoji import Emoji
from ..utils import MISSING from ..utils import MISSING, resolve_annotation
from ..components import ( from ..components import SelectOption, SelectMenu
SelectOption,
SelectMenu,
)
from ..app_commands.namespace import Namespace from ..app_commands.namespace import Namespace
__all__ = ( __all__ = (
@ -72,11 +69,6 @@ if TYPE_CHECKING:
V = TypeVar('V', bound='View', covariant=True) V = TypeVar('V', bound='View', covariant=True)
BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect') BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect')
SelectT = TypeVar('SelectT', bound='Select')
UserSelectT = TypeVar('UserSelectT', bound='UserSelect')
RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect')
ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect')
MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect')
SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT]
selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values')
@ -670,89 +662,32 @@ class ChannelSelect(BaseSelect[V]):
return super().values # type: ignore return super().values # type: ignore
@overload def _get_select_callback_parameter(func: ItemCallbackType[V, BaseSelectT]) -> Type[BaseSelect]:
def select( params = inspect.signature(func).parameters
*, if len(params) != 3:
cls: Type[SelectT] = Select, raise TypeError(
options: List[SelectOption] = MISSING, f'select menu callback {func.__qualname__!r} requires 3 parameters, '
channel_types: List[ChannelType] = ..., 'the view instance (self), the discord.Interaction, and the select menu itself'
placeholder: Optional[str] = ..., )
custom_id: str = ...,
min_values: int = ..., iterator = iter(params.values())
max_values: int = ..., parameter = next(iterator)
disabled: bool = ..., for parameter in iterator:
row: Optional[int] = ..., pass
) -> SelectCallbackDecorator[V, SelectT]:
... if parameter.annotation is parameter.empty:
return Select
@overload resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {})
def select( origin = getattr(resolved, '__origin__', resolved)
*, if origin is BaseSelect or not isinstance(origin, type) or not issubclass(origin, BaseSelect):
cls: Type[UserSelectT], return Select
options: List[SelectOption] = MISSING, return origin
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, UserSelectT]:
...
@overload
def select(
*,
cls: Type[RoleSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, RoleSelectT]:
...
@overload
def select(
*,
cls: Type[ChannelSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, ChannelSelectT]:
...
@overload
def select(
*,
cls: Type[MentionableSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, MentionableSelectT]:
...
def select( def select(
*, *,
cls: Type[BaseSelectT] = Select, cls: Type[BaseSelectT] = Select if TYPE_CHECKING else MISSING,
options: List[SelectOption] = MISSING, options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING, channel_types: List[ChannelType] = MISSING,
placeholder: Optional[str] = None, placeholder: Optional[str] = None,
@ -787,7 +722,10 @@ def select(
.. versionchanged:: 2.1 .. versionchanged:: 2.1
Added the following keyword-arguments: ``cls``, ``channel_types`` Added the following keyword-arguments: ``cls``, ``channel_types``
.. versionchanged:: 2.2
Now infers ``cls`` based on the callback if not supplied.
Example Example
--------- ---------
.. code-block:: python3 .. code-block:: python3
@ -802,10 +740,11 @@ def select(
------------ ------------
cls: Union[Type[:class:`discord.ui.Select`], Type[:class:`discord.ui.UserSelect`], Type[:class:`discord.ui.RoleSelect`], \ cls: Union[Type[:class:`discord.ui.Select`], Type[:class:`discord.ui.UserSelect`], Type[:class:`discord.ui.RoleSelect`], \
Type[:class:`discord.ui.MentionableSelect`], Type[:class:`discord.ui.ChannelSelect`]] Type[:class:`discord.ui.MentionableSelect`], Type[:class:`discord.ui.ChannelSelect`]]
The class to use for the select menu. Defaults to :class:`discord.ui.Select`. You can use other The class to use for the select menu. Defaults to inferring the type from the
select types to display different select menus to the user. See the table above for the different callback if available; otherwise defaults to :class:`discord.ui.Select`.
values you can get from each select type. Subclasses work as well, however the callback in the subclass will You can use other select types to display different select menus to the user.
get overridden. See the table above for the different values you can get from each select type.
Subclasses work as well, however the callback in the subclass will get overridden.
placeholder: Optional[:class:`str`] placeholder: Optional[:class:`str`]
The placeholder text that is shown if nothing is selected, if any. The placeholder text that is shown if nothing is selected, if any.
custom_id: :class:`str` custom_id: :class:`str`
@ -836,11 +775,15 @@ def select(
def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]:
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function') raise TypeError('select function must be a coroutine function')
if not issubclass(cls, BaseSelect): if cls is MISSING:
callback_cls = _get_select_callback_parameter(func)
else:
callback_cls = getattr(cls, '__origin__', cls)
if not issubclass(callback_cls, BaseSelect):
supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"]) supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"])
raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls!r}.') raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {callback_cls!r}.')
func.__discord_ui_model_type__ = cls func.__discord_ui_model_type__ = callback_cls
func.__discord_ui_model_kwargs__ = { func.__discord_ui_model_kwargs__ = {
'placeholder': placeholder, 'placeholder': placeholder,
'custom_id': custom_id, 'custom_id': custom_id,
@ -849,9 +792,9 @@ def select(
'max_values': max_values, 'max_values': max_values,
'disabled': disabled, 'disabled': disabled,
} }
if issubclass(cls, Select): if issubclass(callback_cls, Select):
func.__discord_ui_model_kwargs__['options'] = options func.__discord_ui_model_kwargs__['options'] = options
if issubclass(cls, ChannelSelect): if issubclass(callback_cls, ChannelSelect):
func.__discord_ui_model_kwargs__['channel_types'] = channel_types func.__discord_ui_model_kwargs__['channel_types'] = channel_types
return func return func

Loading…
Cancel
Save