From f9c2ac9d25fe65a99bf92b60002920b20fe55c1d Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 18 Apr 2017 19:05:34 -0400 Subject: [PATCH] Better handling of VOICE_SERVER_UPDATE. This now sort of respects "Awaiting Endpoint..." waiting. I haven't actually tested out this case since it's hard to get it. However this new code does work with the regular connection flow. --- discord/abc.py | 7 +++--- discord/state.py | 4 ++-- discord/voice_client.py | 52 +++++++++++++++++++++++------------------ 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 07d720937..863d6d765 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -796,7 +796,7 @@ class Callable(metaclass=abc.ABCMeta): raise NotImplementedError @asyncio.coroutine - def connect(self, *, timeout=10.0, reconnect=True): + def connect(self, *, timeout=60.0, reconnect=True): """|coro| Connects to voice and creates a :class:`VoiceClient` to establish @@ -805,8 +805,7 @@ class Callable(metaclass=abc.ABCMeta): Parameters ----------- timeout: float - The timeout in seconds to wait for the - initial handshake to be completed. + The timeout in seconds to wait for the voice endpoint. reconnect: bool Whether the bot should automatically attempt a reconnect if a part of the handshake fails @@ -833,6 +832,7 @@ class Callable(metaclass=abc.ABCMeta): raise ClientException('Already connected to a voice channel.') voice = VoiceClient(state=state, timeout=timeout, channel=self) + state._add_voice_client(key_id, voice) try: yield from voice.connect(reconnect=reconnect) @@ -844,5 +844,4 @@ class Callable(metaclass=abc.ABCMeta): pass raise e # re-raise - state._add_voice_client(key_id, voice) return voice diff --git a/discord/state.py b/discord/state.py index 10f6d16e9..1a2f9b98b 100644 --- a/discord/state.py +++ b/discord/state.py @@ -695,8 +695,8 @@ class ConnectionState: key_id = int(data['channel_id']) vc = self._get_voice_client(key_id) - if vc is not None and vc.is_connected(): - compat.create_task(vc._switch_regions()) + if vc is not None: + compat.create_task(vc._create_socket(key_id, data)) def parse_typing_start(self, data): channel = self.get_channel(int(data['channel_id'])) diff --git a/discord/voice_client.py b/discord/voice_client.py index 89f5ab0aa..09edfa2e0 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -93,10 +93,13 @@ class VoiceClient: self.channel = channel self.main_ws = None self.timeout = timeout + self.ws = None self.loop = state.loop self._state = state # this will be used in the AudioPlayer thread self._connected = threading.Event() + self._handshake_complete = asyncio.Event(loop=self.loop) + self._connections = 0 self.sequence = 0 self.timestamp = 0 @@ -135,38 +138,21 @@ class VoiceClient: self.main_ws = ws = state._get_websocket(guild_id) self._connections += 1 - def session_id_found(data): - user_id = data.get('user_id', 0) - _guild_id = data.get(key_name) - return int(user_id) == state.self_id and int(_guild_id) == key_id - - # register the futures for waiting - session_id_future = ws.wait_for('VOICE_STATE_UPDATE', session_id_found) - voice_data_future = ws.wait_for('VOICE_SERVER_UPDATE', lambda d: int(d.get(key_name, 0)) == key_id) - # request joining yield from ws.voice_state(guild_id, channel_id) try: - session_id_data = yield from asyncio.wait_for(session_id_future, timeout=self.timeout, loop=self.loop) - data = yield from asyncio.wait_for(voice_data_future, timeout=self.timeout, loop=state.loop) + yield from asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout, loop=self.loop) except asyncio.TimeoutError as e: yield from ws.voice_state(guild_id, None, self_mute=True) raise e - self.session_id = session_id_data.get('session_id') - self.server_id = data.get(key_name) - self.token = data.get('token') - self.endpoint = data.get('endpoint', '').replace(':80', '') - self.endpoint_ip = socket.gethostbyname(self.endpoint) - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.setblocking(False) - log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip) @asyncio.coroutine def terminate_handshake(self, *, remove=False): guild_id, _ = self.channel._get_voice_state_pair() + self._handshake_complete.clear() yield from self.main_ws.voice_state(guild_id, None, self_mute=True) if remove: @@ -174,10 +160,30 @@ class VoiceClient: self._state._remove_voice_client(key_id) @asyncio.coroutine - def _switch_regions(self): - # just reconnect when we're requested to switch voice regions - # signal the reconnect loop - yield from self.ws.close(1006) + def _create_socket(self, server_id, data): + self._connected.clear() + self.session_id = self.main_ws.session_id + self.server_id = server_id + self.token = data.get('token') + endpoint = data.get('endpoint') + + if endpoint is None or self.token is None: + log.warning('Awaiting endpoint... This requires waiting. ' \ + 'If timeout occurred considering raising the timeout and reconnecting.') + return + + self.endpoint = endpoint.replace(':80', '') + self.endpoint_ip = socket.gethostbyname(self.endpoint) + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + + if self._handshake_complete.is_set(): + # terminate the websocket and handle the reconnect loop if necessary. + self._handshake_complete.clear() + yield from self.ws.close(1006) + return + + self._handshake_complete.set() @asyncio.coroutine def connect(self, *, reconnect=True, _tries=0, do_handshake=True):