diff --git a/discord/ui/view.py b/discord/ui/view.py index b5cf2f0d0..1e79625c3 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -47,6 +47,7 @@ __all__ = ( if TYPE_CHECKING: from ..interactions import Interaction from ..types.components import Component as ComponentPayload + from ..state import ConnectionState def _walk_all_components(components: List[Component]) -> Iterator[Component]: @@ -118,6 +119,7 @@ class View: self.id = os.urandom(16).hex() self._cancel_callback: Optional[Callable[[View], None]] = None + self._timeout_handler: Optional[asyncio.TimerHandle] = None self._stopped = asyncio.Event() def to_components(self) -> List[Dict[str, Any]]: @@ -225,6 +227,13 @@ class View: """ return True + async def on_timeout(self) -> None: + """|coro| + + A callback that is called when a view's timeout elapses without being explicitly stopped. + """ + pass + async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction): try: allow = await self.interaction_check(interaction) @@ -238,6 +247,15 @@ class View: if not interaction.response._responded: await interaction.response.defer() + def _start_listening(self, store: ViewStore) -> None: + self._cancel_callback = partial(store.remove_view) + if self.timeout: + loop = asyncio.get_running_loop() + self._timeout_handler = loop.call_later(self.timeout, self.dispatch_timeout) + + def dispatch_timeout(self): + asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}') + def dispatch(self, state: Any, item: Item, interaction: Interaction): asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}') @@ -268,6 +286,9 @@ class View: This operation cannot be undone. """ self._stopped.set() + if self._timeout_handler: + self._timeout_handler.cancel() + if self._cancel_callback: self._cancel_callback(self) @@ -280,12 +301,12 @@ class View: class ViewStore: - def __init__(self, state): + def __init__(self, state: ConnectionState): # (component_type, custom_id): (View, Item, Expiry) self._views: Dict[Tuple[int, str], Tuple[View, Item, Optional[float]]] = {} # message_id: View self._synced_message_views: Dict[int, View] = {} - self._state = state + self._state: ConnectionState = state def __verify_integrity(self): to_remove: List[Tuple[int, str]] = [] @@ -301,7 +322,7 @@ class ViewStore: self.__verify_integrity() expiry = view._expires_at - view._cancel_callback = partial(self.remove_view) + view._start_listening(self) for item in view.children: if item.is_dispatchable(): self._views[(item.type.value, item.custom_id)] = (view, item, expiry) # type: ignore