Browse Source

Change the way callbacks are defined to allow deriving

This should hopefully make these work more consistently as other
functions do.
pull/6961/head
Rapptz 4 years ago
parent
commit
4c0ebc9221
  1. 33
      discord/ui/button.py
  2. 82
      discord/ui/item.py
  3. 19
      discord/ui/view.py

33
discord/ui/button.py

@ -87,8 +87,6 @@ class Button(Item):
The emoji of the button, if available. The emoji of the button, if available.
""" """
__slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',)
__item_repr_attributes__: Tuple[str, ...] = ( __item_repr_attributes__: Tuple[str, ...] = (
'style', 'style',
'url', 'url',
@ -192,19 +190,6 @@ class Button(Item):
else: else:
self._underlying.emoji = None self._underlying.emoji = None
def copy(self: B) -> B:
button = self.__class__(
style=self.style,
label=self.label,
disabled=self.disabled,
custom_id=self.custom_id,
url=self.url,
emoji=self.emoji,
group=self.group_id,
)
button.callback = self.callback
return button
@classmethod @classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B: def from_component(cls: Type[B], button: ButtonComponent) -> B:
return cls( return cls(
@ -239,7 +224,7 @@ def button(
style: ButtonStyle = ButtonStyle.grey, style: ButtonStyle = ButtonStyle.grey,
emoji: Optional[Union[str, PartialEmoji]] = None, emoji: Optional[Union[str, PartialEmoji]] = None,
group: Optional[int] = None, group: Optional[int] = None,
) -> Callable[[ItemCallbackType], Button]: ) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a button to a component. """A decorator that attaches a button to a component.
The function being decorated should have three parameters, ``self`` representing The function being decorated should have three parameters, ``self`` representing
@ -275,14 +260,22 @@ def button(
ordering. ordering.
""" """
def decorator(func: ItemCallbackType) -> Button: def decorator(func: ItemCallbackType) -> ItemCallbackType:
nonlocal custom_id nonlocal custom_id
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function') raise TypeError('button function must be a coroutine function')
custom_id = custom_id or os.urandom(32).hex() custom_id = custom_id or os.urandom(32).hex()
button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group) func.__discord_ui_model_type__ = Button
button.callback = func func.__discord_ui_model_kwargs__ = {
return button 'style': style,
'custom_id': custom_id,
'url': None,
'disabled': disabled,
'label': label,
'emoji': emoji,
'group': group,
}
return func
return decorator return decorator

82
discord/ui/item.py

@ -24,8 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
import inspect
from ..interactions import Interaction from ..interactions import Interaction
@ -50,25 +49,15 @@ class Item:
- :class:`discord.ui.Button` - :class:`discord.ui.Button`
""" """
__slots__: Tuple[str, ...] = (
'_callback',
'_pass_view_arg',
'group_id',
)
__item_repr_attributes__: Tuple[str, ...] = ('group_id',) __item_repr_attributes__: Tuple[str, ...] = ('group_id',)
def __init__(self): def __init__(self):
self._callback: Optional[ItemCallbackType] = None self._view: Optional[View] = None
self._pass_view_arg = True
self.group_id: Optional[int] = None self.group_id: Optional[int] = None
def to_component_dict(self) -> Dict[str, Any]: def to_component_dict(self) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError
def copy(self: I) -> I:
raise NotImplementedError
def refresh_state(self, component: Component) -> None: def refresh_state(self, component: Component) -> None:
return None return None
@ -88,53 +77,20 @@ class Item:
return f'<{self.__class__.__name__} {attrs}>' return f'<{self.__class__.__name__} {attrs}>'
@property @property
def callback(self) -> Optional[ItemCallbackType]: def view(self) -> Optional[View]:
"""Returns the underlying callback associated with this interaction.""" """Optional[:class:`View`]: The underlying view for this item."""
return self._callback return self._view
@callback.setter async def callback(self, interaction: Interaction):
def callback(self, value: Optional[ItemCallbackType]): """|coro|
if value is None:
self._callback = None The callback associated with this UI item.
return
This can be overriden by subclasses.
# Check if it's a partial function
try: Parameters
partial = value.func -----------
except AttributeError: interaction: :class:`Interaction`
pass The interaction that triggered this UI item.
else: """
if not inspect.iscoroutinefunction(value.func): pass
raise TypeError(f'inner partial function must be a coroutine')
# Check if the partial is bound
try:
bound_partial = partial.__self__
except AttributeError:
pass
else:
self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__')
self._callback = value
return
try:
func_self = value.__self__
except AttributeError:
pass
else:
if not isinstance(func_self, Item):
raise TypeError(f'callback bound method must be from Item not {func_self!r}')
else:
value = value.__func__
if not inspect.iscoroutinefunction(value):
raise TypeError(f'callback must be a coroutine not {value!r}')
self._callback = value
async def _do_call(self, view: View, interaction: Interaction):
if self._pass_view_arg:
await self._callback(view, self, interaction)
else:
await self._callback(self, interaction) # type: ignore

19
discord/ui/view.py

@ -31,7 +31,7 @@ import asyncio
import sys import sys
import time import time
import os import os
from .item import Item from .item import Item, ItemCallbackType
from ..enums import ComponentType from ..enums import ComponentType
from ..components import ( from ..components import (
Component, Component,
@ -95,13 +95,13 @@ class View:
__discord_ui_view__: ClassVar[bool] = True __discord_ui_view__: ClassVar[bool] = True
if TYPE_CHECKING: if TYPE_CHECKING:
__view_children_items__: ClassVar[List[Item]] __view_children_items__: ClassVar[List[ItemCallbackType]]
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
children: List[Item] = [] children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
for member in base.__dict__.values(): for member in base.__dict__.values():
if isinstance(member, Item): if hasattr(member, '__discord_ui_model_type__'):
children.append(member) children.append(member)
if len(children) > 25: if len(children) > 25:
@ -111,7 +111,13 @@ class View:
def __init__(self, timeout: Optional[float] = 180.0): def __init__(self, timeout: Optional[float] = 180.0):
self.timeout = timeout self.timeout = timeout
self.children: List[Item] = [i.copy() for i in self.__view_children_items__] self.children: List[Item] = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
item.callback = partial(func, self, item)
item._view = self
self.children.append(item)
self.id = os.urandom(16).hex() self.id = os.urandom(16).hex()
self._cancel_callback: Optional[Callable[[View], None]] = None self._cancel_callback: Optional[Callable[[View], None]] = None
@ -171,11 +177,12 @@ class View:
if not isinstance(item, Item): if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}') raise TypeError(f'expected Item not {item.__class__!r}')
item._view = self
self.children.append(item) self.children.append(item)
async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction): async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
await state.http.create_interaction_response(interaction.id, interaction.token, type=6) await state.http.create_interaction_response(interaction.id, interaction.token, type=6)
await item._do_call(self, interaction) await item.callback(interaction)
def dispatch(self, state: Any, item: Item, interaction: Interaction): def dispatch(self, state: Any, item: Item, interaction: Interaction):
asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}') asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')

Loading…
Cancel
Save