diff --git a/discord/ui/view.py b/discord/ui/view.py index e6f8df348..27c765947 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -162,14 +162,32 @@ class View: self.__weights = _ViewWeights(self.children) loop = asyncio.get_running_loop() - self.id = os.urandom(16).hex() - self._cancel_callback: Optional[Callable[[View], None]] = None - self._timeout_handler: Optional[asyncio.TimerHandle] = None - self._stopped = loop.create_future() + self.id: str = os.urandom(16).hex() + self.__cancel_callback: Optional[Callable[[View], None]] = None + self.__timeout_expiry: Optional[float] = None + self.__timeout_task: Optional[asyncio.Task[None]] = None + self.__stopped: asyncio.Future[bool] = loop.create_future() def __repr__(self) -> str: return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>' + async def __timeout_task_impl(self) -> None: + while True: + # Guard just in case someone changes the value of the timeout at runtime + if self.timeout is None: + return + + if self.__timeout_expiry is None: + return self._dispatch_timeout() + + # Check if we've elapsed our currently set timeout + now = time.monotonic() + if now >= self.__timeout_expiry: + return self._dispatch_timeout() + + # Wait N seconds to see if timeout data has been refreshed + await asyncio.sleep(self.__timeout_expiry - now) + def to_components(self) -> List[Dict[str, Any]]: def key(item: Item) -> int: return item._rendered_row or 0 @@ -328,8 +346,11 @@ class View: print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr) traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) - async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction): + async def _scheduled_task(self, item: Item, interaction: Interaction): try: + if self.timeout: + self.__timeout_expiry = time.monotonic() + self.timeout + allow = await self.interaction_check(interaction) if not allow: return @@ -340,21 +361,28 @@ class View: except Exception as e: return await self.on_error(e, item, interaction) - def _start_listening(self, store: ViewStore) -> None: - self._cancel_callback = partial(store.remove_view) + def _start_listening_from_store(self, store: ViewStore) -> None: + self.__cancel_callback = partial(store.remove_view) if self.timeout: loop = asyncio.get_running_loop() - self._timeout_handler = loop.call_later(self.timeout, self.dispatch_timeout) + if self.__timeout_task is not None: + self.__timeout_task.cancel() + + self.__timeout_expiry = time.monotonic() + self.timeout + self.__timeout_task = loop.create_task(self.__timeout_task_impl()) - def dispatch_timeout(self): - if self._stopped.done(): + def _dispatch_timeout(self): + if self.__stopped.done(): return - self._stopped.set_result(True) + self.__stopped.set_result(True) asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}') - def dispatch(self, state: Any, item: Item, interaction: Interaction): - asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}') + def _dispatch_item(self, item: Item, interaction: Interaction): + if self.__stopped.done(): + return + + asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}') def refresh(self, components: List[Component]): # This is pretty hacky at the moment @@ -382,23 +410,25 @@ class View: This operation cannot be undone. """ - if not self._stopped.done(): - self._stopped.set_result(False) + if not self.__stopped.done(): + self.__stopped.set_result(False) - if self._timeout_handler: - self._timeout_handler.cancel() + self.__timeout_expiry = None + if self.__timeout_task is not None: + self.__timeout_task.cancel() + self.__timeout_task = None - if self._cancel_callback: - self._cancel_callback(self) - self._cancel_callback = None + if self.__cancel_callback: + self.__cancel_callback(self) + self.__cancel_callback = None def is_finished(self) -> bool: """:class:`bool`: Whether the view has finished interacting.""" - return self._stopped.done() + return self.__stopped.done() def is_dispatching(self) -> bool: """:class:`bool`: Whether the view has been added for dispatching purposes.""" - return self._cancel_callback is not None + return self.__cancel_callback is not None def is_persistent(self) -> bool: """:class:`bool`: Whether the view is set up as persistent. @@ -420,13 +450,13 @@ class View: If ``True``, then the view timed out. If ``False`` then the view finished normally. """ - return await self._stopped + return await self.__stopped class ViewStore: def __init__(self, state: ConnectionState): - # (component_type, custom_id): (View, Item, Expiry) - self._views: Dict[Tuple[int, str], Tuple[View, Item, Optional[float]]] = {} + # (component_type, custom_id): (View, Item) + self._views: Dict[Tuple[int, str], Tuple[View, Item]] = {} # message_id: View self._synced_message_views: Dict[int, View] = {} self._state: ConnectionState = state @@ -436,7 +466,7 @@ class ViewStore: # fmt: off views = { view.id: view - for (_, (view, _, _)) in self._views.items() + for (_, (view, _)) in self._views.items() if view.is_persistent() } # fmt: on @@ -445,8 +475,8 @@ class ViewStore: def __verify_integrity(self): to_remove: List[Tuple[int, str]] = [] now = time.monotonic() - for (k, (_, _, expiry)) in self._views.items(): - if expiry is not None and now >= expiry: + for (k, (view, _)) in self._views.items(): + if view.is_finished(): to_remove.append(k) for k in to_remove: @@ -455,11 +485,10 @@ class ViewStore: def add_view(self, view: View, message_id: Optional[int] = None): self.__verify_integrity() - expiry = view._expires_at - view._start_listening(self) + view._start_listening_from_store(self) for item in view.children: if item.is_dispatchable(): - self._views[(item.type.value, item.custom_id)] = (view, item, expiry) # type: ignore + self._views[(item.type.value, item.custom_id)] = (view, item) # type: ignore if message_id is not None: self._synced_message_views[message_id] = view @@ -481,10 +510,10 @@ class ViewStore: if value is None: return - view, item, _ = value - self._views[key] = (view, item, view._expires_at) + view, item = value + self._views[key] = (view, item) item.refresh_state(interaction) - view.dispatch(self._state, item, interaction) + view._dispatch_item(item, interaction) def is_message_tracked(self, message_id: int): return message_id in self._synced_message_views