@ -81,9 +81,10 @@ _log = logging.getLogger(__name__)
class SocketReader ( threading . Thread ) :
def __init__ ( self , state : VoiceConnectionState ) - > None :
def __init__ ( self , state : VoiceConnectionState , * , start_paused : bool = True ) - > None :
super ( ) . __init__ ( daemon = True , name = f ' voice-socket-reader: { id ( self ) : #x } ' )
self . state : VoiceConnectionState = state
self . start_paused = start_paused
self . _callbacks : List [ SocketReaderCallback ] = [ ]
self . _running = threading . Event ( )
self . _end = threading . Event ( )
@ -130,6 +131,8 @@ class SocketReader(threading.Thread):
def run ( self ) - > None :
self . _end . clear ( )
self . _running . set ( )
if self . start_paused :
self . pause ( )
try :
self . _do_run ( )
except Exception :
@ -148,7 +151,10 @@ class SocketReader(threading.Thread):
# Since this socket is a non blocking socket, select has to be used to wait on it for reading.
try :
readable , _ , _ = select . select ( [ self . state . socket ] , [ ] , [ ] , 30 )
except ( ValueError , TypeError ) :
except ( ValueError , TypeError , OSError ) as e :
_log . debug (
" Select error handling socket in reader, this should be safe to ignore: %s : %s " , e . __class__ . __name__ , e
)
# The socket is either closed or doesn't exist at the moment
continue
@ -305,6 +311,10 @@ class VoiceConnectionState:
_log . debug ( ' Ignoring unexpected voice_state_update event ' )
async def voice_server_update ( self , data : VoiceServerUpdatePayload ) - > None :
previous_token = self . token
previous_server_id = self . server_id
previous_endpoint = self . endpoint
self . token = data [ ' token ' ]
self . server_id = int ( data [ ' guild_id ' ] )
endpoint = data . get ( ' endpoint ' )
@ -338,6 +348,10 @@ class VoiceConnectionState:
self . state = ConnectionFlowState . got_voice_server_update
elif self . state is not ConnectionFlowState . disconnected :
# eventual consistency
if previous_token == self . token and previous_server_id == self . server_id and previous_token == self . token :
return
_log . debug ( ' Unexpected server update event, attempting to handle ' )
await self . soft_disconnect ( with_state = ConnectionFlowState . got_voice_server_update )
@ -422,7 +436,7 @@ class VoiceConnectionState:
if not self . _runner :
self . _runner = self . voice_client . loop . create_task ( self . _poll_voice_ws ( reconnect ) , name = ' Voice websocket poller ' )
async def disconnect ( self , * , force : bool = True , cleanup : bool = True ) - > None :
async def disconnect ( self , * , force : bool = True , cleanup : bool = True , wait : bool = False ) - > None :
if not force and not self . is_connected ( ) :
return
@ -433,23 +447,26 @@ class VoiceConnectionState:
except Exception :
_log . debug ( ' Ignoring exception disconnecting from voice ' , exc_info = True )
finally :
self . ip = MISSING
self . port = MISSING
self . state = ConnectionFlowState . disconnected
self . _socket_reader . pause ( )
# Stop threads before we unlock waiters so they end properly
if cleanup :
self . _socket_reader . stop ( )
self . voice_client . stop ( )
# Flip the connected event to unlock any waiters
self . _connected . set ( )
self . _connected . clear ( )
if cleanup :
self . _socket_reader . stop ( )
if self . socket :
self . socket . close ( )
self . ip = MISSING
self . port = MISSING
# Skip this part if disconnect was called from the poll loop task
if self . _runner and asyncio . current_task ( ) != self . _runner :
if wait and not self . _inside_runner ( ) :
# Wait for the voice_state_update event confirming the bot left the voice channel.
# This prevents a race condition caused by disconnecting and immediately connecting again.
# The new VoiceConnectionState object receives the voice_state_update event containing channel=None while still
@ -458,7 +475,9 @@ class VoiceConnectionState:
async with atimeout ( self . timeout ) :
await self . _disconnected . wait ( )
except TimeoutError :
_log . debug ( ' Timed out waiting for disconnect confirmation event ' )
_log . debug ( ' Timed out waiting for voice disconnection confirmation ' )
except asyncio . CancelledError :
pass
if cleanup :
self . voice_client . cleanup ( )
@ -476,23 +495,26 @@ class VoiceConnectionState:
except Exception :
_log . debug ( ' Ignoring exception soft disconnecting from voice ' , exc_info = True )
finally :
self . ip = MISSING
self . port = MISSING
self . state = with_state
self . _socket_reader . pause ( )
if self . socket :
self . socket . close ( )
self . ip = MISSING
self . port = MISSING
async def move_to ( self , channel : Optional [ abc . Snowflake ] , timeout : Optional [ float ] ) - > None :
if channel is None :
await self . disconnect ( )
# This function should only be called externally so its ok to wait for the disconnect.
await self . disconnect ( wait = True )
return
if self . voice_client . channel and channel . id == self . voice_client . channel . id :
return
previous_state = self . state
# this is only an outgoing ws request
# if it fails, nothing happens and nothing changes (besides self.state)
await self . _move_to ( channel )
@ -504,7 +526,6 @@ class VoiceConnectionState:
_log . warning ( ' Timed out trying to move to channel %s in guild %s ' , channel . id , self . guild . id )
if self . state is last_state :
_log . debug ( ' Reverting to previous state %s ' , previous_state . name )
self . state = previous_state
def wait ( self , timeout : Optional [ float ] = None ) - > bool :
@ -527,6 +548,9 @@ class VoiceConnectionState:
_log . debug ( ' Unregistering socket listener callback %s ' , callback )
self . _socket_reader . unregister ( callback )
def _inside_runner ( self ) - > bool :
return self . _runner is not None and asyncio . current_task ( ) == self . _runner
async def _wait_for_state (
self , state : ConnectionFlowState , * other_states : ConnectionFlowState , timeout : Optional [ float ] = None
) - > None :
@ -590,10 +614,20 @@ class VoiceConnectionState:
break
if exc . code == 4014 :
# We were disconnected by discord
# This condition is a race between the main ws event and the voice ws closing
if self . _disconnected . is_set ( ) :
_log . info ( ' Disconnected from voice by discord, close code %d . ' , exc . code )
await self . disconnect ( )
break
# We may have been moved to a different channel
_log . info ( ' Disconnected from voice by force... potentially reconnecting. ' )
successful = await self . _potential_reconnect ( )
if not successful :
_log . info ( ' Reconnect was unsuccessful, disconnecting from voice normally... ' )
# Don't bother to disconnect if already disconnected
if self . state is not ConnectionFlowState . disconnected :
await self . disconnect ( )
break
else :
@ -626,10 +660,16 @@ class VoiceConnectionState:
async def _potential_reconnect ( self ) - > bool :
try :
await self . _wait_for_state (
ConnectionFlowState . got_voice_server_update , ConnectionFlowState . got_both_voice_updates , timeout = self . timeout
ConnectionFlowState . got_voice_server_update ,
ConnectionFlowState . got_both_voice_updates ,
ConnectionFlowState . disconnected ,
timeout = self . timeout ,
)
except asyncio . TimeoutError :
return False
else :
if self . state is ConnectionFlowState . disconnected :
return False
previous_ws = self . ws
try :