From 9db0dadc423b23819e0d5ca54b3135a08d4c1a56 Mon Sep 17 00:00:00 2001
From: Imayhaveborkedit <imayhaveborkedit@users.noreply.github.com>
Date: Thu, 14 Dec 2023 19:09:04 -0500
Subject: [PATCH] Fix voice disconnect+connect race condition

Fixes a race condition when disconnecting and immediately connecting
again.  Also fixes disconnect() being called twice.

Let me be clear, I DO NOT LIKE THIS SOLUTION.  I think it's dumb but I
don't see any other reasonable alternative.  There isn't a way to
transfer state to a new connection state object and I can't think of a
nice way to do it either.  That said, waiting an arbitrary amount of
time for an arbitrary websocket event doesn't seem like the right
solution either, but it's the best I can do at this point.
---
 discord/voice_state.py | 34 +++++++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/discord/voice_state.py b/discord/voice_state.py
index d24f7ea4e..6a680a106 100644
--- a/discord/voice_state.py
+++ b/discord/voice_state.py
@@ -212,6 +212,7 @@ class VoiceConnectionState:
         self._expecting_disconnect: bool = False
         self._connected = threading.Event()
         self._state_event = asyncio.Event()
+        self._disconnected = asyncio.Event()
         self._runner: Optional[asyncio.Task] = None
         self._connector: Optional[asyncio.Task] = None
         self._socket_reader = SocketReader(self)
@@ -254,8 +255,10 @@ class VoiceConnectionState:
         channel_id = data['channel_id']
 
         if channel_id is None:
+            self._disconnected.set()
+
             # If we know we're going to get a voice_state_update where we have no channel due to
-            # being in the reconnect flow, we ignore it.  Otherwise, it probably wasn't from us.
+            # being in the reconnect or disconnect flow, we ignore it.  Otherwise, it probably wasn't from us.
             if self._expecting_disconnect:
                 self._expecting_disconnect = False
             else:
@@ -419,9 +422,9 @@ class VoiceConnectionState:
             return
 
         try:
+            await self._voice_disconnect()
             if self.ws:
                 await self.ws.close()
-            await self._voice_disconnect()
         except Exception:
             _log.debug('Ignoring exception disconnecting from voice', exc_info=True)
         finally:
@@ -436,11 +439,25 @@ class VoiceConnectionState:
 
             if cleanup:
                 self._socket_reader.stop()
-                self.voice_client.cleanup()
 
             if self.socket:
                 self.socket.close()
 
+            # Skip this part if disconnect was called from the poll loop task
+            if self._runner and asyncio.current_task() != self._runner:
+                # Wait for the voice_state_update event confirming the bot left the voice channel.
+                # This prevents a race condition caused by disconnecting and immediately connecting again.
+                # The new VoiceConnectionState object receives the voice_state_update event containing channel=None while still
+                # connecting leaving it in a bad state.  Since there's no nice way to transfer state to the new one, we have to do this.
+                try:
+                    async with atimeout(self.timeout):
+                        await self._disconnected.wait()
+                except TimeoutError:
+                    _log.debug('Timed out waiting for disconnect confirmation event')
+
+            if cleanup:
+                self.voice_client.cleanup()
+
     async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None:
         _log.debug('Soft disconnecting from voice')
         # Stop the websocket reader because closing the websocket will trigger an unwanted reconnect
@@ -524,6 +541,7 @@ class VoiceConnectionState:
         self.state = ConnectionFlowState.disconnected
         await self.voice_client.channel.guild.change_voice_state(channel=None)
         self._expecting_disconnect = True
+        self._disconnected.clear()
 
     async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket:
         ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook)
@@ -557,8 +575,10 @@ class VoiceConnectionState:
                     # 4014 - we were externally disconnected (voice channel deleted, we were moved, etc)
                     # 4015 - voice server has crashed
                     if exc.code in (1000, 4015):
-                        _log.info('Disconnecting from voice normally, close code %d.', exc.code)
-                        await self.disconnect()
+                        # Don't call disconnect a second time if the websocket closed from a disconnect call
+                        if not self._expecting_disconnect:
+                            _log.info('Disconnecting from voice normally, close code %d.', exc.code)
+                            await self.disconnect()
                         break
 
                     if exc.code == 4014:
@@ -602,6 +622,8 @@ class VoiceConnectionState:
             )
         except asyncio.TimeoutError:
             return False
+
+        previous_ws = self.ws
         try:
             self.ws = await self._connect_websocket(False)
             await self._handshake_websocket()
@@ -609,6 +631,8 @@ class VoiceConnectionState:
             return False
         else:
             return True
+        finally:
+            await previous_ws.close()
 
     async def _move_to(self, channel: abc.Snowflake) -> None:
         await self.voice_client.channel.guild.change_voice_state(channel=channel)