Browse Source

chore: some bunch fixes and make interaction_check's work on every item

pull/10166/head
DA-344 3 months ago
parent
commit
cf949c689f
  1. 26
      discord/ui/action_row.py
  2. 10
      discord/ui/button.py
  3. 7
      discord/ui/container.py
  4. 13
      discord/ui/item.py
  5. 16
      discord/ui/select.py
  6. 12
      discord/ui/view.py

26
discord/ui/action_row.py

@ -75,8 +75,8 @@ __all__ = ('ActionRow',)
class _ActionRowCallback: class _ActionRowCallback:
__slots__ = ('row', 'callback', 'item') __slots__ = ('row', 'callback', 'item')
def __init__(self, callback: ItemCallbackType[Any, Any], row: ActionRow, item: Item[Any]) -> None: def __init__(self, callback: ItemCallbackType[Any], row: ActionRow, item: Item[Any]) -> None:
self.callback: ItemCallbackType[Any, Any] = callback self.callback: ItemCallbackType[Any] = callback
self.row: ActionRow = row self.row: ActionRow = row
self.item: Item[Any] = item self.item: Item[Any] = item
@ -97,7 +97,7 @@ class ActionRow(Item[V]):
The ID of this component. This must be unique across the view. The ID of this component. This must be unique across the view.
""" """
__action_row_children_items__: ClassVar[List[ItemCallbackType[Any, Any]]] = [] __action_row_children_items__: ClassVar[List[ItemCallbackType[Any]]] = []
__discord_ui_action_row__: ClassVar[bool] = True __discord_ui_action_row__: ClassVar[bool] = True
__discord_ui_update_view__: ClassVar[bool] = True __discord_ui_update_view__: ClassVar[bool] = True
@ -110,7 +110,7 @@ class ActionRow(Item[V]):
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
super().__init_subclass__() super().__init_subclass__()
children: Dict[str, ItemCallbackType[Any, Any]] = {} children: Dict[str, ItemCallbackType[Any]] = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
for name, member in base.__dict__.items(): for name, member in base.__dict__.items():
if hasattr(member, '__discord_ui_model_type__'): if hasattr(member, '__discord_ui_model_type__'):
@ -269,7 +269,7 @@ class ActionRow(Item[V]):
disabled: bool = False, disabled: bool = False,
style: ButtonStyle = ButtonStyle.secondary, style: ButtonStyle = ButtonStyle.secondary,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
) -> Callable[[ItemCallbackType[V, Button[V]]], Button[V]]: ) -> Callable[[ItemCallbackType[Button[V]]], Button[V]]:
"""A decorator that attaches a button to a component. """A decorator that attaches a button to a component.
The function being decorated should have three parameters, ``self`` representing The function being decorated should have three parameters, ``self`` representing
@ -302,7 +302,7 @@ class ActionRow(Item[V]):
or a full :class:`.Emoji`. or a full :class:`.Emoji`.
""" """
def decorator(func: ItemCallbackType[V, Button[V]]) -> ItemCallbackType[V, Button[V]]: def decorator(func: ItemCallbackType[Button[V]]) -> ItemCallbackType[Button[V]]:
ret = _button( ret = _button(
label=label, label=label,
custom_id=custom_id, custom_id=custom_id,
@ -328,7 +328,7 @@ class ActionRow(Item[V]):
min_values: int = ..., min_values: int = ...,
max_values: int = ..., max_values: int = ...,
disabled: bool = ..., disabled: bool = ...,
) -> SelectCallbackDecorator[V, SelectT]: ) -> SelectCallbackDecorator[SelectT]:
... ...
@overload @overload
@ -344,7 +344,7 @@ class ActionRow(Item[V]):
max_values: int = ..., max_values: int = ...,
disabled: bool = ..., disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
) -> SelectCallbackDecorator[V, UserSelectT]: ) -> SelectCallbackDecorator[UserSelectT]:
... ...
@overload @overload
@ -360,7 +360,7 @@ class ActionRow(Item[V]):
max_values: int = ..., max_values: int = ...,
disabled: bool = ..., disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
) -> SelectCallbackDecorator[V, RoleSelectT]: ) -> SelectCallbackDecorator[RoleSelectT]:
... ...
@overload @overload
@ -376,7 +376,7 @@ class ActionRow(Item[V]):
max_values: int = ..., max_values: int = ...,
disabled: bool = ..., disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
) -> SelectCallbackDecorator[V, ChannelSelectT]: ) -> SelectCallbackDecorator[ChannelSelectT]:
... ...
@overload @overload
@ -392,7 +392,7 @@ class ActionRow(Item[V]):
max_values: int = ..., max_values: int = ...,
disabled: bool = ..., disabled: bool = ...,
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
) -> SelectCallbackDecorator[V, MentionableSelectT]: ) -> SelectCallbackDecorator[MentionableSelectT]:
... ...
def select( def select(
@ -407,7 +407,7 @@ class ActionRow(Item[V]):
max_values: int = 1, max_values: int = 1,
disabled: bool = False, disabled: bool = False,
default_values: Sequence[ValidDefaultValues] = MISSING, default_values: Sequence[ValidDefaultValues] = MISSING,
) -> SelectCallbackDecorator[V, BaseSelectT]: ) -> SelectCallbackDecorator[BaseSelectT]:
"""A decorator that attaches a select menu to a component. """A decorator that attaches a select menu to a component.
The function being decorated should have three parameters, ``self`` representing The function being decorated should have three parameters, ``self`` representing
@ -477,7 +477,7 @@ class ActionRow(Item[V]):
Number of items must be in range of ``min_values`` and ``max_values``. Number of items must be in range of ``min_values`` and ``max_values``.
""" """
def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: def decorator(func: ItemCallbackType[BaseSelectT]) -> ItemCallbackType[BaseSelectT]:
r = _select( # type: ignore r = _select( # type: ignore
cls=cls, # type: ignore cls=cls, # type: ignore
placeholder=placeholder, placeholder=placeholder,

10
discord/ui/button.py

@ -281,7 +281,8 @@ def button(
style: ButtonStyle = ButtonStyle.secondary, style: ButtonStyle = ButtonStyle.secondary,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
row: Optional[int] = None, row: Optional[int] = None,
) -> Callable[[ItemCallbackType[V, Button[V]]], Button[V]]: id: Optional[int] = None,
) -> Callable[[ItemCallbackType[Button[V]]], Button[V]]:
"""A decorator that attaches a button to a component. """A decorator that attaches a button to a component.
The function being decorated should have three parameters, ``self`` representing The function being decorated should have three parameters, ``self`` representing
@ -318,9 +319,13 @@ def button(
like to control the relative positioning of the row then passing an index is advised. like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
id: Optional[:class:`int`]
The ID of this component. This must be unique across the view.
.. versionadded:: 2.6
""" """
def decorator(func: ItemCallbackType[V, Button[V]]) -> ItemCallbackType[V, Button[V]]: def decorator(func: ItemCallbackType[Button[V]]) -> ItemCallbackType[Button[V]]:
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function') raise TypeError('button function must be a coroutine function')
@ -334,6 +339,7 @@ def button(
'emoji': emoji, 'emoji': emoji,
'row': row, 'row': row,
'sku_id': None, 'sku_id': None,
'id': id,
} }
return func return func

7
discord/ui/container.py

@ -46,8 +46,8 @@ __all__ = ('Container',)
class _ContainerCallback: class _ContainerCallback:
__slots__ = ('container', 'callback', 'item') __slots__ = ('container', 'callback', 'item')
def __init__(self, callback: ItemCallbackType[Any, Any], container: Container, item: Item[Any]) -> None: def __init__(self, callback: ItemCallbackType[Any], container: Container, item: Item[Any]) -> None:
self.callback: ItemCallbackType[Any, Any] = callback self.callback: ItemCallbackType[Any] = callback
self.container: Container = container self.container: Container = container
self.item: Item[Any] = item self.item: Item[Any] = item
@ -63,7 +63,7 @@ class Container(Item[V]):
Parameters Parameters
---------- ----------
children: List[:class:`Item`] children: List[:class:`Item`]
The initial children or :class:`View` s of this container. Can have up to 10 The initial children of this container. Can have up to 10
items. items.
accent_colour: Optional[:class:`.Colour`] accent_colour: Optional[:class:`.Colour`]
The colour of the container. Defaults to ``None``. The colour of the container. Defaults to ``None``.
@ -124,6 +124,7 @@ class Container(Item[V]):
if getattr(raw, '__discord_ui_section__', False) and raw.accessory.is_dispatchable(): # type: ignore if getattr(raw, '__discord_ui_section__', False) and raw.accessory.is_dispatchable(): # type: ignore
self.__dispatchable.append(raw.accessory) # type: ignore self.__dispatchable.append(raw.accessory) # type: ignore
elif getattr(raw, '__discord_ui_action_row__', False) and raw.is_dispatchable(): elif getattr(raw, '__discord_ui_action_row__', False) and raw.is_dispatchable():
raw._parent = self # type: ignore
self.__dispatchable.extend(raw._children) # type: ignore self.__dispatchable.extend(raw._children) # type: ignore
else: else:
# action rows can be created inside containers, and then callbacks can exist here # action rows can be created inside containers, and then callbacks can exist here

13
discord/ui/item.py

@ -43,7 +43,7 @@ if TYPE_CHECKING:
I = TypeVar('I', bound='Item[Any]') I = TypeVar('I', bound='Item[Any]')
V = TypeVar('V', bound='BaseView', covariant=True) V = TypeVar('V', bound='BaseView', covariant=True)
ItemCallbackType = Callable[[V, Interaction[Any], I], Coroutine[Any, Any, Any]] ItemCallbackType = Callable[[Any, Interaction[Any], I], Coroutine[Any, Any, Any]]
class Item(Generic[V]): class Item(Generic[V]):
@ -151,6 +151,17 @@ class Item(Generic[V]):
def id(self, value: Optional[int]) -> None: def id(self, value: Optional[int]) -> None:
self._id = value self._id = value
async def _run_checks(self, interaction: Interaction[ClientT]) -> bool:
can_run = await self.interaction_check(interaction)
if can_run:
parent = getattr(self, '_parent', None)
if parent is not None:
can_run = await parent._run_checks(interaction)
return can_run
async def callback(self, interaction: Interaction[ClientT]) -> Any: async def callback(self, interaction: Interaction[ClientT]) -> Any:
"""|coro| """|coro|

16
discord/ui/select.py

@ -109,7 +109,7 @@ UserSelectT = TypeVar('UserSelectT', bound='UserSelect[Any]')
RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect[Any]') RoleSelectT = TypeVar('RoleSelectT', bound='RoleSelect[Any]')
ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]') ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]')
MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]') MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]')
SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[BaseSelectT]], BaseSelectT]
DefaultSelectComponentTypes = Literal[ DefaultSelectComponentTypes = Literal[
ComponentType.user_select, ComponentType.user_select,
ComponentType.role_select, ComponentType.role_select,
@ -936,7 +936,7 @@ def select(
disabled: bool = ..., disabled: bool = ...,
row: Optional[int] = ..., row: Optional[int] = ...,
id: Optional[int] = ..., id: Optional[int] = ...,
) -> SelectCallbackDecorator[V, SelectT]: ) -> SelectCallbackDecorator[SelectT]:
... ...
@ -954,7 +954,7 @@ def select(
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ..., row: Optional[int] = ...,
id: Optional[int] = ..., id: Optional[int] = ...,
) -> SelectCallbackDecorator[V, UserSelectT]: ) -> SelectCallbackDecorator[UserSelectT]:
... ...
@ -972,7 +972,7 @@ def select(
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ..., row: Optional[int] = ...,
id: Optional[int] = ..., id: Optional[int] = ...,
) -> SelectCallbackDecorator[V, RoleSelectT]: ) -> SelectCallbackDecorator[RoleSelectT]:
... ...
@ -990,7 +990,7 @@ def select(
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ..., row: Optional[int] = ...,
id: Optional[int] = ..., id: Optional[int] = ...,
) -> SelectCallbackDecorator[V, ChannelSelectT]: ) -> SelectCallbackDecorator[ChannelSelectT]:
... ...
@ -1008,7 +1008,7 @@ def select(
default_values: Sequence[ValidDefaultValues] = ..., default_values: Sequence[ValidDefaultValues] = ...,
row: Optional[int] = ..., row: Optional[int] = ...,
id: Optional[int] = ..., id: Optional[int] = ...,
) -> SelectCallbackDecorator[V, MentionableSelectT]: ) -> SelectCallbackDecorator[MentionableSelectT]:
... ...
@ -1025,7 +1025,7 @@ def select(
default_values: Sequence[ValidDefaultValues] = MISSING, default_values: Sequence[ValidDefaultValues] = MISSING,
row: Optional[int] = None, row: Optional[int] = None,
id: Optional[int] = None, id: Optional[int] = None,
) -> SelectCallbackDecorator[V, BaseSelectT]: ) -> SelectCallbackDecorator[BaseSelectT]:
"""A decorator that attaches a select menu to a component. """A decorator that attaches a select menu to a component.
The function being decorated should have three parameters, ``self`` representing The function being decorated should have three parameters, ``self`` representing
@ -1110,7 +1110,7 @@ def select(
.. versionadded:: 2.6 .. versionadded:: 2.6
""" """
def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: def decorator(func: ItemCallbackType[BaseSelectT]) -> ItemCallbackType[BaseSelectT]:
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function') raise TypeError('select function must be a coroutine function')
callback_cls = getattr(cls, '__origin__', cls) callback_cls = getattr(cls, '__origin__', cls)

12
discord/ui/view.py

@ -84,7 +84,7 @@ if TYPE_CHECKING:
from ..state import ConnectionState from ..state import ConnectionState
from .modal import Modal from .modal import Modal
ItemLike = Union[ItemCallbackType[Any, Any], Item[Any]] ItemLike = Union[ItemCallbackType[Any], Item[Any]]
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -185,8 +185,8 @@ class _ViewWeights:
class _ViewCallback: class _ViewCallback:
__slots__ = ('view', 'callback', 'item') __slots__ = ('view', 'callback', 'item')
def __init__(self, callback: ItemCallbackType[Any, Any], view: BaseView, item: Item[BaseView]) -> None: def __init__(self, callback: ItemCallbackType[Any], view: BaseView, item: Item[BaseView]) -> None:
self.callback: ItemCallbackType[Any, Any] = callback self.callback: ItemCallbackType[Any] = callback
self.view: BaseView = view self.view: BaseView = view
self.item: Item[BaseView] = item self.item: Item[BaseView] = item
@ -452,7 +452,7 @@ class BaseView:
try: try:
item._refresh_state(interaction, interaction.data) # type: ignore item._refresh_state(interaction, interaction.data) # type: ignore
allow = await item.interaction_check(interaction) and await self.interaction_check(interaction) allow = await item._run_checks(interaction) and await self.interaction_check(interaction)
if not allow: if not allow:
return return
@ -587,7 +587,7 @@ class View(BaseView):
) )
super().__init_subclass__() super().__init_subclass__()
children: Dict[str, ItemCallbackType[Any, Any]] = {} children: Dict[str, ItemCallbackType[Any]] = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
for name, member in base.__dict__.items(): for name, member in base.__dict__.items():
if hasattr(member, '__discord_ui_model_type__'): if hasattr(member, '__discord_ui_model_type__'):
@ -716,7 +716,7 @@ class LayoutView(BaseView):
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
children: Dict[str, Item[Any]] = {} children: Dict[str, Item[Any]] = {}
callback_children: Dict[str, ItemCallbackType[Any, Any]] = {} callback_children: Dict[str, ItemCallbackType[Any]] = {}
row = 0 row = 0

Loading…
Cancel
Save