From 4c0ebc922155c2d9ca3129e0dbfdcea10f3ad777 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 26 Apr 2021 06:02:43 -0400 Subject: [PATCH] Change the way callbacks are defined to allow deriving This should hopefully make these work more consistently as other functions do. --- discord/ui/button.py | 33 +++++++----------- discord/ui/item.py | 82 ++++++++++---------------------------------- discord/ui/view.py | 19 ++++++---- 3 files changed, 45 insertions(+), 89 deletions(-) diff --git a/discord/ui/button.py b/discord/ui/button.py index afc69f7ab..8ff4ce744 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -87,8 +87,6 @@ class Button(Item): The emoji of the button, if available. """ - __slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',) - __item_repr_attributes__: Tuple[str, ...] = ( 'style', 'url', @@ -192,19 +190,6 @@ class Button(Item): else: self._underlying.emoji = None - def copy(self: B) -> B: - button = self.__class__( - style=self.style, - label=self.label, - disabled=self.disabled, - custom_id=self.custom_id, - url=self.url, - emoji=self.emoji, - group=self.group_id, - ) - button.callback = self.callback - return button - @classmethod def from_component(cls: Type[B], button: ButtonComponent) -> B: return cls( @@ -239,7 +224,7 @@ def button( style: ButtonStyle = ButtonStyle.grey, emoji: Optional[Union[str, PartialEmoji]] = None, group: Optional[int] = None, -) -> Callable[[ItemCallbackType], Button]: +) -> Callable[[ItemCallbackType], ItemCallbackType]: """A decorator that attaches a button to a component. The function being decorated should have three parameters, ``self`` representing @@ -275,14 +260,22 @@ def button( ordering. """ - def decorator(func: ItemCallbackType) -> Button: + def decorator(func: ItemCallbackType) -> ItemCallbackType: nonlocal custom_id if not inspect.iscoroutinefunction(func): raise TypeError('button function must be a coroutine function') custom_id = custom_id or os.urandom(32).hex() - button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group) - button.callback = func - return button + func.__discord_ui_model_type__ = Button + func.__discord_ui_model_kwargs__ = { + 'style': style, + 'custom_id': custom_id, + 'url': None, + 'disabled': disabled, + 'label': label, + 'emoji': emoji, + 'group': group, + } + return func return decorator diff --git a/discord/ui/item.py b/discord/ui/item.py index 7726407e9..dc6c91a0f 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -24,8 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union -import inspect +from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from ..interactions import Interaction @@ -50,25 +49,15 @@ class Item: - :class:`discord.ui.Button` """ - __slots__: Tuple[str, ...] = ( - '_callback', - '_pass_view_arg', - 'group_id', - ) - __item_repr_attributes__: Tuple[str, ...] = ('group_id',) def __init__(self): - self._callback: Optional[ItemCallbackType] = None - self._pass_view_arg = True + self._view: Optional[View] = None self.group_id: Optional[int] = None def to_component_dict(self) -> Dict[str, Any]: raise NotImplementedError - def copy(self: I) -> I: - raise NotImplementedError - def refresh_state(self, component: Component) -> None: return None @@ -88,53 +77,20 @@ class Item: return f'<{self.__class__.__name__} {attrs}>' @property - def callback(self) -> Optional[ItemCallbackType]: - """Returns the underlying callback associated with this interaction.""" - return self._callback - - @callback.setter - def callback(self, value: Optional[ItemCallbackType]): - if value is None: - self._callback = None - return - - # Check if it's a partial function - try: - partial = value.func - except AttributeError: - pass - else: - if not inspect.iscoroutinefunction(value.func): - raise TypeError(f'inner partial function must be a coroutine') - - # Check if the partial is bound - try: - bound_partial = partial.__self__ - except AttributeError: - pass - else: - self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__') - - self._callback = value - return - - try: - func_self = value.__self__ - except AttributeError: - pass - else: - if not isinstance(func_self, Item): - raise TypeError(f'callback bound method must be from Item not {func_self!r}') - else: - value = value.__func__ - - if not inspect.iscoroutinefunction(value): - raise TypeError(f'callback must be a coroutine not {value!r}') - - self._callback = value - - async def _do_call(self, view: View, interaction: Interaction): - if self._pass_view_arg: - await self._callback(view, self, interaction) - else: - await self._callback(self, interaction) # type: ignore + def view(self) -> Optional[View]: + """Optional[:class:`View`]: The underlying view for this item.""" + return self._view + + async def callback(self, interaction: Interaction): + """|coro| + + The callback associated with this UI item. + + This can be overriden by subclasses. + + Parameters + ----------- + interaction: :class:`Interaction` + The interaction that triggered this UI item. + """ + pass diff --git a/discord/ui/view.py b/discord/ui/view.py index 273a45d0b..712f787a1 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -31,7 +31,7 @@ import asyncio import sys import time import os -from .item import Item +from .item import Item, ItemCallbackType from ..enums import ComponentType from ..components import ( Component, @@ -95,13 +95,13 @@ class View: __discord_ui_view__: ClassVar[bool] = True if TYPE_CHECKING: - __view_children_items__: ClassVar[List[Item]] + __view_children_items__: ClassVar[List[ItemCallbackType]] def __init_subclass__(cls) -> None: - children: List[Item] = [] + children: List[ItemCallbackType] = [] for base in reversed(cls.__mro__): for member in base.__dict__.values(): - if isinstance(member, Item): + if hasattr(member, '__discord_ui_model_type__'): children.append(member) if len(children) > 25: @@ -111,7 +111,13 @@ class View: def __init__(self, timeout: Optional[float] = 180.0): self.timeout = timeout - self.children: List[Item] = [i.copy() for i in self.__view_children_items__] + self.children: List[Item] = [] + for func in self.__view_children_items__: + item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) + item.callback = partial(func, self, item) + item._view = self + self.children.append(item) + self.id = os.urandom(16).hex() self._cancel_callback: Optional[Callable[[View], None]] = None @@ -171,11 +177,12 @@ class View: if not isinstance(item, Item): raise TypeError(f'expected Item not {item.__class__!r}') + item._view = self self.children.append(item) async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction): await state.http.create_interaction_response(interaction.id, interaction.token, type=6) - await item._do_call(self, interaction) + await item.callback(interaction) def dispatch(self, state: Any, item: Item, interaction: Interaction): asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')