Browse Source

Change the way callbacks are defined to allow deriving

This should hopefully make these work more consistently as other
functions do.
pull/6961/head
Rapptz 4 years ago
parent
commit
4c0ebc9221
  1. 33
      discord/ui/button.py
  2. 82
      discord/ui/item.py
  3. 19
      discord/ui/view.py

33
discord/ui/button.py

@ -87,8 +87,6 @@ class Button(Item):
The emoji of the button, if available.
"""
__slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',)
__item_repr_attributes__: Tuple[str, ...] = (
'style',
'url',
@ -192,19 +190,6 @@ class Button(Item):
else:
self._underlying.emoji = None
def copy(self: B) -> B:
button = self.__class__(
style=self.style,
label=self.label,
disabled=self.disabled,
custom_id=self.custom_id,
url=self.url,
emoji=self.emoji,
group=self.group_id,
)
button.callback = self.callback
return button
@classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B:
return cls(
@ -239,7 +224,7 @@ def button(
style: ButtonStyle = ButtonStyle.grey,
emoji: Optional[Union[str, PartialEmoji]] = None,
group: Optional[int] = None,
) -> Callable[[ItemCallbackType], Button]:
) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a button to a component.
The function being decorated should have three parameters, ``self`` representing
@ -275,14 +260,22 @@ def button(
ordering.
"""
def decorator(func: ItemCallbackType) -> Button:
def decorator(func: ItemCallbackType) -> ItemCallbackType:
nonlocal custom_id
if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
custom_id = custom_id or os.urandom(32).hex()
button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group)
button.callback = func
return button
func.__discord_ui_model_type__ = Button
func.__discord_ui_model_kwargs__ = {
'style': style,
'custom_id': custom_id,
'url': None,
'disabled': disabled,
'label': label,
'emoji': emoji,
'group': group,
}
return func
return decorator

82
discord/ui/item.py

@ -24,8 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
import inspect
from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from ..interactions import Interaction
@ -50,25 +49,15 @@ class Item:
- :class:`discord.ui.Button`
"""
__slots__: Tuple[str, ...] = (
'_callback',
'_pass_view_arg',
'group_id',
)
__item_repr_attributes__: Tuple[str, ...] = ('group_id',)
def __init__(self):
self._callback: Optional[ItemCallbackType] = None
self._pass_view_arg = True
self._view: Optional[View] = None
self.group_id: Optional[int] = None
def to_component_dict(self) -> Dict[str, Any]:
raise NotImplementedError
def copy(self: I) -> I:
raise NotImplementedError
def refresh_state(self, component: Component) -> None:
return None
@ -88,53 +77,20 @@ class Item:
return f'<{self.__class__.__name__} {attrs}>'
@property
def callback(self) -> Optional[ItemCallbackType]:
"""Returns the underlying callback associated with this interaction."""
return self._callback
@callback.setter
def callback(self, value: Optional[ItemCallbackType]):
if value is None:
self._callback = None
return
# Check if it's a partial function
try:
partial = value.func
except AttributeError:
pass
else:
if not inspect.iscoroutinefunction(value.func):
raise TypeError(f'inner partial function must be a coroutine')
# Check if the partial is bound
try:
bound_partial = partial.__self__
except AttributeError:
pass
else:
self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__')
self._callback = value
return
try:
func_self = value.__self__
except AttributeError:
pass
else:
if not isinstance(func_self, Item):
raise TypeError(f'callback bound method must be from Item not {func_self!r}')
else:
value = value.__func__
if not inspect.iscoroutinefunction(value):
raise TypeError(f'callback must be a coroutine not {value!r}')
self._callback = value
async def _do_call(self, view: View, interaction: Interaction):
if self._pass_view_arg:
await self._callback(view, self, interaction)
else:
await self._callback(self, interaction) # type: ignore
def view(self) -> Optional[View]:
"""Optional[:class:`View`]: The underlying view for this item."""
return self._view
async def callback(self, interaction: Interaction):
"""|coro|
The callback associated with this UI item.
This can be overriden by subclasses.
Parameters
-----------
interaction: :class:`Interaction`
The interaction that triggered this UI item.
"""
pass

19
discord/ui/view.py

@ -31,7 +31,7 @@ import asyncio
import sys
import time
import os
from .item import Item
from .item import Item, ItemCallbackType
from ..enums import ComponentType
from ..components import (
Component,
@ -95,13 +95,13 @@ class View:
__discord_ui_view__: ClassVar[bool] = True
if TYPE_CHECKING:
__view_children_items__: ClassVar[List[Item]]
__view_children_items__: ClassVar[List[ItemCallbackType]]
def __init_subclass__(cls) -> None:
children: List[Item] = []
children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__):
for member in base.__dict__.values():
if isinstance(member, Item):
if hasattr(member, '__discord_ui_model_type__'):
children.append(member)
if len(children) > 25:
@ -111,7 +111,13 @@ class View:
def __init__(self, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.children: List[Item] = [i.copy() for i in self.__view_children_items__]
self.children: List[Item] = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
item.callback = partial(func, self, item)
item._view = self
self.children.append(item)
self.id = os.urandom(16).hex()
self._cancel_callback: Optional[Callable[[View], None]] = None
@ -171,11 +177,12 @@ class View:
if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}')
item._view = self
self.children.append(item)
async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
await state.http.create_interaction_response(interaction.id, interaction.token, type=6)
await item._do_call(self, interaction)
await item.callback(interaction)
def dispatch(self, state: Any, item: Item, interaction: Interaction):
asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')

Loading…
Cancel
Save