Browse Source

Infer select type from callback annotation

pull/9141/head
Zephyrkul 2 years ago
committed by GitHub
parent
commit
b671958e11
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 143
      discord/ui/select.py

143
discord/ui/select.py

@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload
from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict
from contextvars import ContextVar
import inspect
import os
@ -31,11 +31,8 @@ from .item import Item, ItemCallbackType
from ..enums import ChannelType, ComponentType
from ..partial_emoji import PartialEmoji
from ..emoji import Emoji
from ..utils import MISSING
from ..components import (
SelectOption,
SelectMenu,
)
from ..utils import MISSING, resolve_annotation
from ..components import SelectOption, SelectMenu
from ..app_commands.namespace import Namespace
__all__ = (
@ -72,11 +69,6 @@ if TYPE_CHECKING:
V = TypeVar('V', bound='View', covariant=True)
BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect')
SelectT = TypeVar('SelectT', bound='Select')
UserSelectT = TypeVar('UserSelectT', bound='UserSelect')
RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect')
ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect')
MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect')
SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT]
selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values')
@ -670,89 +662,32 @@ class ChannelSelect(BaseSelect[V]):
return super().values # type: ignore
@overload
def select(
*,
cls: Type[SelectT] = Select,
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, SelectT]:
...
@overload
def select(
*,
cls: Type[UserSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, UserSelectT]:
...
@overload
def select(
*,
cls: Type[RoleSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, RoleSelectT]:
...
@overload
def select(
*,
cls: Type[ChannelSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = ...,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, ChannelSelectT]:
...
@overload
def select(
*,
cls: Type[MentionableSelectT],
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING,
placeholder: Optional[str] = ...,
custom_id: str = ...,
min_values: int = ...,
max_values: int = ...,
disabled: bool = ...,
row: Optional[int] = ...,
) -> SelectCallbackDecorator[V, MentionableSelectT]:
...
def _get_select_callback_parameter(func: ItemCallbackType[V, BaseSelectT]) -> Type[BaseSelect]:
params = inspect.signature(func).parameters
if len(params) != 3:
raise TypeError(
f'select menu callback {func.__qualname__!r} requires 3 parameters, '
'the view instance (self), the discord.Interaction, and the select menu itself'
)
iterator = iter(params.values())
parameter = next(iterator)
for parameter in iterator:
pass
if parameter.annotation is parameter.empty:
return Select
resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {})
origin = getattr(resolved, '__origin__', resolved)
if origin is BaseSelect or not isinstance(origin, type) or not issubclass(origin, BaseSelect):
return Select
return origin
def select(
*,
cls: Type[BaseSelectT] = Select,
cls: Type[BaseSelectT] = Select if TYPE_CHECKING else MISSING,
options: List[SelectOption] = MISSING,
channel_types: List[ChannelType] = MISSING,
placeholder: Optional[str] = None,
@ -787,7 +722,10 @@ def select(
.. versionchanged:: 2.1
Added the following keyword-arguments: ``cls``, ``channel_types``
.. versionchanged:: 2.2
Now infers ``cls`` based on the callback if not supplied.
Example
---------
.. code-block:: python3
@ -802,10 +740,11 @@ def select(
------------
cls: Union[Type[:class:`discord.ui.Select`], Type[:class:`discord.ui.UserSelect`], Type[:class:`discord.ui.RoleSelect`], \
Type[:class:`discord.ui.MentionableSelect`], Type[:class:`discord.ui.ChannelSelect`]]
The class to use for the select menu. Defaults to :class:`discord.ui.Select`. You can use other
select types to display different select menus to the user. See the table above for the different
values you can get from each select type. Subclasses work as well, however the callback in the subclass will
get overridden.
The class to use for the select menu. Defaults to inferring the type from the
callback if available; otherwise defaults to :class:`discord.ui.Select`.
You can use other select types to display different select menus to the user.
See the table above for the different values you can get from each select type.
Subclasses work as well, however the callback in the subclass will get overridden.
placeholder: Optional[:class:`str`]
The placeholder text that is shown if nothing is selected, if any.
custom_id: :class:`str`
@ -836,11 +775,15 @@ def select(
def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]:
if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function')
if not issubclass(cls, BaseSelect):
if cls is MISSING:
callback_cls = _get_select_callback_parameter(func)
else:
callback_cls = getattr(cls, '__origin__', cls)
if not issubclass(callback_cls, BaseSelect):
supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"])
raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls!r}.')
raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {callback_cls!r}.')
func.__discord_ui_model_type__ = cls
func.__discord_ui_model_type__ = callback_cls
func.__discord_ui_model_kwargs__ = {
'placeholder': placeholder,
'custom_id': custom_id,
@ -849,9 +792,9 @@ def select(
'max_values': max_values,
'disabled': disabled,
}
if issubclass(cls, Select):
if issubclass(callback_cls, Select):
func.__discord_ui_model_kwargs__['options'] = options
if issubclass(cls, ChannelSelect):
if issubclass(callback_cls, ChannelSelect):
func.__discord_ui_model_kwargs__['channel_types'] = channel_types
return func

Loading…
Cancel
Save