From 44284ae107c8cd504aa02a6ac650039d438079af Mon Sep 17 00:00:00 2001 From: Imayhaveborkedit Date: Thu, 28 Sep 2023 17:51:22 -0400 Subject: [PATCH] Rewrite voice connection internals --- discord/abc.py | 4 +- discord/gateway.py | 74 +++-- discord/player.py | 30 +- discord/voice_client.py | 310 +++++---------------- discord/voice_state.py | 596 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 735 insertions(+), 279 deletions(-) create mode 100644 discord/voice_state.py diff --git a/discord/abc.py b/discord/abc.py index 71eaff6ab..4c1f24618 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1842,7 +1842,7 @@ class Connectable(Protocol): async def connect( self, *, - timeout: float = 60.0, + timeout: float = 30.0, reconnect: bool = True, cls: Callable[[Client, Connectable], T] = VoiceClient, self_deaf: bool = False, @@ -1858,7 +1858,7 @@ class Connectable(Protocol): Parameters ----------- timeout: :class:`float` - The timeout in seconds to wait for the voice endpoint. + The timeout in seconds to wait for the connection to complete. 
reconnect: :class:`bool` Whether the bot should automatically attempt a reconnect if a part of the handshake fails diff --git a/discord/gateway.py b/discord/gateway.py index 551e36a55..4f98bc2c1 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -34,7 +34,7 @@ import threading import traceback import zlib -from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar +from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple import aiohttp import yarl @@ -59,7 +59,7 @@ if TYPE_CHECKING: from .client import Client from .state import ConnectionState - from .voice_client import VoiceClient + from .voice_state import VoiceConnectionState class ReconnectWebSocket(Exception): @@ -797,7 +797,7 @@ class DiscordVoiceWebSocket: if TYPE_CHECKING: thread_id: int - _connection: VoiceClient + _connection: VoiceConnectionState gateway: str _max_heartbeat_timeout: float @@ -866,16 +866,21 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) @classmethod - async def from_client( - cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None + async def from_connection_state( + cls, + state: VoiceConnectionState, + *, + resume: bool = False, + hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, ) -> Self: """Creates a voice websocket for the :class:`VoiceClient`.""" - gateway = 'wss://' + client.endpoint + '/?v=4' + gateway = f'wss://{state.endpoint}/?v=4' + client = state.voice_client http = client._state.http socket = await http.ws_connect(gateway, compress=15) ws = cls(socket, loop=client.loop, hook=hook) ws.gateway = gateway - ws._connection = client + ws._connection = state ws._max_heartbeat_timeout = 60.0 ws.thread_id = threading.get_ident() @@ -951,29 +956,49 @@ class DiscordVoiceWebSocket: state.voice_port = data['port'] state.endpoint_ip = data['ip'] + _log.debug('Connecting to voice 
socket') + await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port)) + + state.ip, state.port = await self.discover_ip() + # there *should* always be at least one supported mode (xsalsa20_poly1305) + modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] + _log.debug('received supported encryption modes: %s', ', '.join(modes)) + + mode = modes[0] + await self.select_protocol(state.ip, state.port, mode) + _log.debug('selected the voice protocol for use (%s)', mode) + + async def discover_ip(self) -> Tuple[str, int]: + state = self._connection packet = bytearray(74) struct.pack_into('>H', packet, 0, 1) # 1 = Send struct.pack_into('>H', packet, 2, 70) # 70 = Length struct.pack_into('>I', packet, 4, state.ssrc) - state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) - recv = await self.loop.sock_recv(state.socket, 74) - _log.debug('received packet in initial_connection: %s', recv) + + _log.debug('Sending ip discovery packet') + await self.loop.sock_sendall(state.socket, packet) + + fut: asyncio.Future[bytes] = self.loop.create_future() + + def get_ip_packet(data: bytes): + if data[1] == 0x02 and len(data) == 74: + self.loop.call_soon_threadsafe(fut.set_result, data) + + fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet)) + state.add_socket_listener(get_ip_packet) + recv = await fut + + _log.debug('Received ip discovery packet: %s', recv) # the ip is ascii starting at the 8th byte and ending at the first null ip_start = 8 ip_end = recv.index(0, ip_start) - state.ip = recv[ip_start:ip_end].decode('ascii') + ip = recv[ip_start:ip_end].decode('ascii') - state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] - _log.debug('detected ip: %s port: %s', state.ip, state.port) + port = struct.unpack_from('>H', recv, len(recv) - 2)[0] + _log.debug('detected ip: %s port: %s', ip, port) - # there *should* always be at least one supported mode (xsalsa20_poly1305) - modes = [mode for mode 
in data['modes'] if mode in self._connection.supported_modes] - _log.debug('received supported encryption modes: %s', ", ".join(modes)) - - mode = modes[0] - await self.select_protocol(state.ip, state.port, mode) - _log.debug('selected the voice protocol for use (%s)', mode) + return ip, port @property def latency(self) -> float: @@ -995,9 +1020,8 @@ class DiscordVoiceWebSocket: self.secret_key = self._connection.secret_key = data['secret_key'] # Send a speak command with the "not speaking" state. - # This also tells Discord our SSRC value, which Discord requires - # before sending any voice data (and is the real reason why we - # call this here). + # This also tells Discord our SSRC value, which Discord requires before + # sending any voice data (and is the real reason why we call this here). await self.speak(SpeakingState.none) async def poll_event(self) -> None: @@ -1006,10 +1030,10 @@ class DiscordVoiceWebSocket: if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug('Received voice %s', msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): - _log.debug('Received %s', msg) + _log.debug('Received voice %s', msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) async def close(self, code: int = 1000) -> None: diff --git a/discord/player.py b/discord/player.py index b9106f738..147c0628a 100644 --- a/discord/player.py +++ b/discord/player.py @@ -703,7 +703,6 @@ class AudioPlayer(threading.Thread): self._resumed: threading.Event = threading.Event() self._resumed.set() # we are not paused self._current_error: Optional[Exception] = None - self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() if after is not None and not callable(after): @@ -714,7 +713,8 @@ class 
AudioPlayer(threading.Thread): self._start = time.perf_counter() # getattr lookup speed ups - play_audio = self.client.send_audio_packet + client = self.client + play_audio = client.send_audio_packet self._speak(SpeakingState.voice) while not self._end.is_set(): @@ -725,22 +725,28 @@ class AudioPlayer(threading.Thread): self._resumed.wait() continue - # are we disconnected from voice? - if not self._connected.is_set(): - # wait until we are connected - self._connected.wait() - # reset our internal data - self.loops = 0 - self._start = time.perf_counter() - - self.loops += 1 data = self.source.read() if not data: self.stop() break + # are we disconnected from voice? + if not client.is_connected(): + _log.debug('Not connected, waiting for %ss...', client.timeout) + # wait until we are connected, but not forever + connected = client.wait_until_connected(client.timeout) + if self._end.is_set() or not connected: + _log.debug('Aborting playback') + return + _log.debug('Reconnected, resuming playback') + self._speak(SpeakingState.voice) + # reset our internal data + self.loops = 0 + self._start = time.perf_counter() + play_audio(data, encode=not self.source.is_opus()) + self.loops += 1 next_time = self._start + self.DELAY * self.loops delay = max(0, self.DELAY + (next_time - time.perf_counter())) time.sleep(delay) @@ -792,7 +798,7 @@ class AudioPlayer(threading.Thread): def is_paused(self) -> bool: return not self._end.is_set() and not self._resumed.is_set() - def _set_source(self, source: AudioSource) -> None: + def set_source(self, source: AudioSource) -> None: with self._lock: self.pause(update_speaking=False) self.source = source diff --git a/discord/voice_client.py b/discord/voice_client.py index 8309218a1..d991f1476 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -20,40 +20,24 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -Some documentation to refer to: - -- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. -- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. -- We pull the session_id from VOICE_STATE_UPDATE. -- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. -- Then we initiate the voice web socket (vWS) pointing to the endpoint. -- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. -- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. -- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. -- Then we send our IP and port via vWS with opcode 1. -- When that's all done, we receive opcode 4 from the vWS. -- Finally we can transmit data to endpoint:port. """ from __future__ import annotations import asyncio -import socket import logging import struct -import threading from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union -from . import opus, utils -from .backoff import ExponentialBackoff +from . 
import opus from .gateway import * -from .errors import ClientException, ConnectionClosed +from .errors import ClientException from .player import AudioPlayer, AudioSource from .utils import MISSING +from .voice_state import VoiceConnectionState if TYPE_CHECKING: + from .gateway import DiscordVoiceWebSocket from .client import Client from .guild import Guild from .state import ConnectionState @@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol): """ channel: VocalGuildChannel - endpoint_ip: str - voice_port: int - ip: str - port: int - secret_key: List[int] - ssrc: int def __init__(self, client: Client, channel: abc.Connectable) -> None: if not has_nacl: @@ -239,29 +217,18 @@ class VoiceClient(VoiceProtocol): super().__init__(client, channel) state = client._connection - self.token: str = MISSING self.server_id: int = MISSING self.socket = MISSING self.loop: asyncio.AbstractEventLoop = state.loop self._state: ConnectionState = state - # this will be used in the AudioPlayer thread - self._connected: threading.Event = threading.Event() - self._handshaking: bool = False - self._potentially_reconnecting: bool = False - self._voice_state_complete: asyncio.Event = asyncio.Event() - self._voice_server_complete: asyncio.Event = asyncio.Event() - - self.mode: str = MISSING - self._connections: int = 0 self.sequence: int = 0 self.timestamp: int = 0 - self.timeout: float = 0 - self._runner: asyncio.Task = MISSING self._player: Optional[AudioPlayer] = None self.encoder: Encoder = MISSING self._lite_nonce: int = 0 - self.ws: DiscordVoiceWebSocket = MISSING + + self._connection: VoiceConnectionState = self.create_connection_state() warn_nacl: bool = not has_nacl supported_modes: Tuple[SupportedModes, ...] = ( @@ -280,6 +247,38 @@ class VoiceClient(VoiceProtocol): """:class:`ClientUser`: The user connected to voice (i.e. 
ourselves).""" return self._state.user # type: ignore + @property + def session_id(self) -> Optional[str]: + return self._connection.session_id + + @property + def token(self) -> Optional[str]: + return self._connection.token + + @property + def endpoint(self) -> Optional[str]: + return self._connection.endpoint + + @property + def ssrc(self) -> int: + return self._connection.ssrc + + @property + def mode(self) -> SupportedModes: + return self._connection.mode + + @property + def secret_key(self) -> List[int]: + return self._connection.secret_key + + @property + def ws(self) -> DiscordVoiceWebSocket: + return self._connection.ws + + @property + def timeout(self) -> float: + return self._connection.timeout + def checked_add(self, attr: str, value: int, limit: int) -> None: val = getattr(self, attr) if val + value > limit: @@ -289,149 +288,23 @@ class VoiceClient(VoiceProtocol): # connection related + def create_connection_state(self) -> VoiceConnectionState: + return VoiceConnectionState(self) + async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - self.session_id: str = data['session_id'] - channel_id = data['channel_id'] - - if not self._handshaking or self._potentially_reconnecting: - # If we're done handshaking then we just need to update ourselves - # If we're potentially reconnecting due to a 4014, then we need to differentiate - # a channel move and an actual force disconnect - if channel_id is None: - # We're being disconnected so cleanup - await self.disconnect() - else: - self.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore - else: - self._voice_state_complete.set() + await self._connection.voice_state_update(data) async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: - if self._voice_server_complete.is_set(): - _log.warning('Ignoring extraneous voice server update.') - return - - self.token = data['token'] - self.server_id = int(data['guild_id']) - endpoint = 
data.get('endpoint') - - if endpoint is None or self.token is None: - _log.warning( - 'Awaiting endpoint... This requires waiting. ' - 'If timeout occurred considering raising the timeout and reconnecting.' - ) - return - - self.endpoint, _, _ = endpoint.rpartition(':') - if self.endpoint.startswith('wss://'): - # Just in case, strip it off since we're going to add it later - self.endpoint: str = self.endpoint[6:] - - # This gets set later - self.endpoint_ip = MISSING - - self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.setblocking(False) - - if not self._handshaking: - # If we're not handshaking then we need to terminate our previous connection in the websocket - await self.ws.close(4000) - return - - self._voice_server_complete.set() - - async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None: - await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute) - - async def voice_disconnect(self) -> None: - _log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) - await self.channel.guild.change_voice_state(channel=None) - - def prepare_handshake(self) -> None: - self._voice_state_complete.clear() - self._voice_server_complete.clear() - self._handshaking = True - _log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) - self._connections += 1 - - def finish_handshake(self) -> None: - _log.info('Voice handshake complete. 
Endpoint found %s', self.endpoint) - self._handshaking = False - self._voice_server_complete.clear() - self._voice_state_complete.clear() - - async def connect_websocket(self) -> DiscordVoiceWebSocket: - ws = await DiscordVoiceWebSocket.from_client(self) - self._connected.clear() - while ws.secret_key is None: - await ws.poll_event() - self._connected.set() - return ws + await self._connection.voice_server_update(data) async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None: - _log.info('Connecting to voice...') - self.timeout = timeout - - for i in range(5): - self.prepare_handshake() - - # This has to be created before we start the flow. - futures = [ - self._voice_state_complete.wait(), - self._voice_server_complete.wait(), - ] - - # Start the connection flow - await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute) - - try: - await utils.sane_wait_for(futures, timeout=timeout) - except asyncio.TimeoutError: - await self.disconnect(force=True) - raise - - self.finish_handshake() - - try: - self.ws = await self.connect_websocket() - break - except (ConnectionClosed, asyncio.TimeoutError): - if reconnect: - _log.exception('Failed to connect to voice... 
Retrying...') - await asyncio.sleep(1 + i * 2.0) - await self.voice_disconnect() - continue - else: - raise - - if self._runner is MISSING: - self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect)) - - async def potential_reconnect(self) -> bool: - # Attempt to stop the player thread from playing early - self._connected.clear() - self.prepare_handshake() - self._potentially_reconnecting = True - try: - # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected - await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) - except asyncio.TimeoutError: - self._potentially_reconnecting = False - await self.disconnect(force=True) - return False - - self.finish_handshake() - self._potentially_reconnecting = False - - if self.ws: - _log.debug("Closing existing voice websocket") - await self.ws.close() + await self._connection.connect( + reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False + ) - try: - self.ws = await self.connect_websocket() - except (ConnectionClosed, asyncio.TimeoutError): - return False - else: - return True + def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool: + self._connection.wait(timeout) + return self._connection.is_connected() @property def latency(self) -> float: @@ -442,7 +315,7 @@ class VoiceClient(VoiceProtocol): .. versionadded:: 1.4 """ - ws = self.ws + ws = self._connection.ws return float("inf") if not ws else ws.latency @property @@ -451,72 +324,19 @@ class VoiceClient(VoiceProtocol): .. 
versionadded:: 1.4 """ - ws = self.ws + ws = self._connection.ws return float("inf") if not ws else ws.average_latency - async def poll_voice_ws(self, reconnect: bool) -> None: - backoff = ExponentialBackoff() - while True: - try: - await self.ws.poll_event() - except (ConnectionClosed, asyncio.TimeoutError) as exc: - if isinstance(exc, ConnectionClosed): - # The following close codes are undocumented so I will document them here. - # 1000 - normal closure (obviously) - # 4014 - voice channel has been deleted. - # 4015 - voice server has crashed - if exc.code in (1000, 4015): - _log.info('Disconnecting from voice normally, close code %d.', exc.code) - await self.disconnect() - break - if exc.code == 4014: - _log.info('Disconnected from voice by force... potentially reconnecting.') - successful = await self.potential_reconnect() - if not successful: - _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') - await self.disconnect() - break - else: - continue - - if not reconnect: - await self.disconnect() - raise - - retry = backoff.delay() - _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) - self._connected.clear() - await asyncio.sleep(retry) - await self.voice_disconnect() - try: - await self.connect(reconnect=True, timeout=self.timeout) - except asyncio.TimeoutError: - # at this point we've retried 5 times... let's continue the loop. - _log.warning('Could not connect to voice... Retrying...') - continue - async def disconnect(self, *, force: bool = False) -> None: """|coro| Disconnects this voice client from voice. 
""" - if not force and not self.is_connected(): - return - self.stop() - self._connected.clear() - - try: - if self.ws: - await self.ws.close() - - await self.voice_disconnect() - finally: - self.cleanup() - if self.socket: - self.socket.close() + await self._connection.disconnect(force=force) + self.cleanup() - async def move_to(self, channel: Optional[abc.Snowflake]) -> None: + async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None: """|coro| Moves you to a different voice channel. @@ -525,12 +345,22 @@ class VoiceClient(VoiceProtocol): ----------- channel: Optional[:class:`abc.Snowflake`] The channel to move to. Must be a voice channel. + timeout: Optional[:class:`float`] + How long to wait for the move to complete. + + .. versionadded:: 2.4 + + Raises + ------- + asyncio.TimeoutError + The move did not complete in time, but may still be ongoing. """ - await self.channel.guild.change_voice_state(channel=channel) + await self._connection.move_to(channel) + await self._connection.wait_async(timeout) def is_connected(self) -> bool: """Indicates if the voice client is connected to voice.""" - return self._connected.is_set() + return self._connection.is_connected() # audio related @@ -703,7 +533,7 @@ class VoiceClient(VoiceProtocol): if self._player is None: raise ValueError('Not playing anything.') - self._player._set_source(value) + self._player.set_source(value) def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: """Sends an audio packet composed of the data. 
@@ -732,8 +562,8 @@ class VoiceClient(VoiceProtocol): encoded_data = data packet = self._get_voice_packet(encoded_data) try: - self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) - except BlockingIOError: - _log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) + self._connection.send_packet(packet) + except OSError: + _log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) diff --git a/discord/voice_state.py b/discord/voice_state.py new file mode 100644 index 000000000..f8ab1fa54 --- /dev/null +++ b/discord/voice_state.py @@ -0,0 +1,596 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + + +Some documentation to refer to: + +- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. 
+- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. +- We pull the session_id from VOICE_STATE_UPDATE. +- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. +- Then we initiate the voice web socket (vWS) pointing to the endpoint. +- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. +- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. +- We send a UDP discovery packet to endpoint:port and receive our IP and our port in BE. +- Then we send our IP and port via vWS with opcode 1. +- When that's all done, we receive opcode 4 from the vWS. +- Finally we can transmit data to endpoint:port. +""" + +from __future__ import annotations + +import select +import socket +import asyncio +import logging +import threading + +import async_timeout + +from typing import TYPE_CHECKING, Optional, Dict, List, Callable, Coroutine, Any, Tuple + +from .enums import Enum +from .utils import MISSING, sane_wait_for +from .errors import ConnectionClosed +from .backoff import ExponentialBackoff +from .gateway import DiscordVoiceWebSocket + +if TYPE_CHECKING: + from . 
import abc + from .guild import Guild + from .user import ClientUser + from .member import VoiceState + from .voice_client import VoiceClient + + from .types.voice import ( + GuildVoiceState as GuildVoiceStatePayload, + VoiceServerUpdate as VoiceServerUpdatePayload, + SupportedModes, + ) + + WebsocketHook = Optional[Callable[['VoiceConnectionState', Dict[str, Any]], Coroutine[Any, Any, Any]]] + SocketReaderCallback = Callable[[bytes], Any] + + +__all__ = ('VoiceConnectionState',) + +_log = logging.getLogger(__name__) + + +class SocketReader(threading.Thread): + def __init__(self, state: VoiceConnectionState) -> None: + super().__init__(daemon=True, name=f'voice-socket-reader:{id(self):#x}') + self.state: VoiceConnectionState = state + self._callbacks: List[SocketReaderCallback] = [] + self._running = threading.Event() + self._end = threading.Event() + # If we have paused reading due to having no callbacks + self._idle_paused: bool = True + + def register(self, callback: SocketReaderCallback) -> None: + self._callbacks.append(callback) + if self._idle_paused: + self._idle_paused = False + self._running.set() + + def unregister(self, callback: SocketReaderCallback) -> None: + try: + self._callbacks.remove(callback) + except ValueError: + pass + else: + if not self._callbacks and self._running.is_set(): + # If running is not set, we are either explicitly paused and + # should be explicitly resumed, or we are already idle paused + self._idle_paused = True + self._running.clear() + + def pause(self) -> None: + self._idle_paused = False + self._running.clear() + + def resume(self, *, force: bool = False) -> None: + if self._running.is_set(): + return + # Don't resume if there are no callbacks registered + if not force and not self._callbacks: + # We tried to resume but there was nothing to do, so resume when ready + self._idle_paused = True + return + self._idle_paused = False + self._running.set() + + def stop(self) -> None: + self._end.set() + self._running.set() + + 
def run(self) -> None: + self._end.clear() + self._running.set() + try: + self._do_run() + except Exception: + _log.exception('Error in %s', self) + finally: + self.stop() + self._running.clear() + self._callbacks.clear() + + def _do_run(self) -> None: + while not self._end.is_set(): + if not self._running.is_set(): + self._running.wait() + continue + + # Since this socket is a non blocking socket, select has to be used to wait on it for reading. + try: + readable, _, _ = select.select([self.state.socket], [], [], 30) + except (ValueError, TypeError): + # The socket is either closed or doesn't exist at the moment + continue + + if not readable: + continue + + try: + data = self.state.socket.recv(2048) + except OSError: + _log.debug('Error reading from socket in %s, this should be safe to ignore', self, exc_info=True) + else: + for cb in self._callbacks: + try: + cb(data) + except Exception: + _log.exception('Error calling %s in %s', cb, self) + + +class ConnectionFlowState(Enum): + """Enum representing voice connection flow state.""" + + # fmt: off + disconnected = 0 + set_guild_voice_state = 1 + got_voice_state_update = 2 + got_voice_server_update = 3 + got_both_voice_updates = 4 + websocket_connected = 5 + got_websocket_ready = 6 + got_ip_discovery = 7 + connected = 8 + # fmt: on + + +class VoiceConnectionState: + """Represents the internal state of a voice connection.""" + + def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] = None) -> None: + self.voice_client = voice_client + self.hook = hook + + self.timeout: float = 30.0 + self.reconnect: bool = True + self.self_deaf: bool = False + self.self_mute: bool = False + self.token: Optional[str] = None + self.session_id: Optional[str] = None + self.endpoint: Optional[str] = None + self.endpoint_ip: Optional[str] = None + self.server_id: Optional[int] = None + self.ip: Optional[str] = None + self.port: Optional[int] = None + self.voice_port: Optional[int] = None + self.secret_key: 
List[int] = MISSING + self.ssrc: int = MISSING + self.mode: SupportedModes = MISSING + self.socket: socket.socket = MISSING + self.ws: DiscordVoiceWebSocket = MISSING + + self._state: ConnectionFlowState = ConnectionFlowState.disconnected + self._expecting_disconnect: bool = False + self._connected = threading.Event() + self._state_event = asyncio.Event() + self._runner: Optional[asyncio.Task] = None + self._connector: Optional[asyncio.Task] = None + self._socket_reader = SocketReader(self) + self._socket_reader.start() + + @property + def state(self) -> ConnectionFlowState: + return self._state + + @state.setter + def state(self, state: ConnectionFlowState) -> None: + if state is not self._state: + _log.debug('Connection state changed to %s', state.name) + self._state = state + self._state_event.set() + self._state_event.clear() + + if state is ConnectionFlowState.connected: + self._connected.set() + else: + self._connected.clear() + + @property + def guild(self) -> Guild: + return self.voice_client.guild + + @property + def user(self) -> ClientUser: + return self.voice_client.user + + @property + def supported_modes(self) -> Tuple[SupportedModes, ...]: + return self.voice_client.supported_modes + + @property + def self_voice_state(self) -> Optional[VoiceState]: + return self.guild.me.voice + + async def voice_state_update(self, data: GuildVoiceStatePayload) -> None: + channel_id = data['channel_id'] + + if channel_id is None: + # If we know we're going to get a voice_state_update where we have no channel due to + # being in the reconnect flow, we ignore it. Otherwise, it probably wasn't from us. 
+ if self._expecting_disconnect: + self._expecting_disconnect = False + else: + _log.debug('We were externally disconnected from voice.') + await self.disconnect() + + return + + self.session_id = data['session_id'] + + # we got the event while connecting + if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_server_update): + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_state_update + else: + self.state = ConnectionFlowState.got_both_voice_updates + return + + if self.state is ConnectionFlowState.connected: + self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore + + elif self.state is not ConnectionFlowState.disconnected: + if channel_id != self.voice_client.channel.id: + # For some unfortunate reason we were moved during the connection flow + _log.info('Handling channel move while connecting...') + + self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore + + await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + else: + _log.debug('Ignoring unexpected voice_state_update event') + + async def voice_server_update(self, data: VoiceServerUpdatePayload) -> None: + self.token = data['token'] + self.server_id = int(data['guild_id']) + endpoint = data.get('endpoint') + + if self.token is None or endpoint is None: + _log.warning( + 'Awaiting endpoint... This requires waiting. ' + 'If timeout occurred considering raising the timeout and reconnecting.' 
+ ) + return + + self.endpoint, _, _ = endpoint.rpartition(':') + if self.endpoint.startswith('wss://'): + # Just in case, strip it off since we're going to add it later + self.endpoint = self.endpoint[6:] + + # we got the event while connecting + if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update): + # This gets set after READY is received + self.endpoint_ip = MISSING + self._create_socket() + + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_server_update + else: + self.state = ConnectionFlowState.got_both_voice_updates + + elif self.state is ConnectionFlowState.connected: + _log.debug('Voice server update, closing old voice websocket') + await self.ws.close(4014) + self.state = ConnectionFlowState.got_voice_server_update + + elif self.state is not ConnectionFlowState.disconnected: + _log.debug('Unexpected server update event, attempting to handle') + + await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + self._create_socket() + + async def connect( + self, *, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool, wait: bool = True + ) -> None: + if self._connector: + self._connector.cancel() + self._connector = None + + if self._runner: + self._runner.cancel() + self._runner = None + + self.timeout = timeout + self.reconnect = reconnect + self._connector = self.voice_client.loop.create_task( + self._wrap_connect(reconnect, timeout, self_deaf, self_mute, resume), name='Voice connector' + ) + if wait: + await self._connector + + async def _wrap_connect(self, *args: Any) -> None: + try: + await self._connect(*args) + except asyncio.CancelledError: + _log.debug('Cancelling voice 
connection') + await self.soft_disconnect() + raise + except asyncio.TimeoutError: + _log.info('Timed out connecting to voice') + await self.disconnect() + raise + except Exception: + _log.exception('Error connecting to voice... disconnecting') + await self.disconnect() + raise + + async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None: + _log.info('Connecting to voice...') + + async with async_timeout.timeout(timeout): + for i in range(5): + _log.info('Starting voice handshake... (connection attempt %d)', i + 1) + + await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute) + # Setting this unnecessarily will break reconnecting + if self.state is ConnectionFlowState.disconnected: + self.state = ConnectionFlowState.set_guild_voice_state + + await self._wait_for_state(ConnectionFlowState.got_both_voice_updates) + + _log.info('Voice handshake complete. Endpoint found: %s', self.endpoint) + + try: + self.ws = await self._connect_websocket(resume) + await self._handshake_websocket() + break + except ConnectionClosed: + if reconnect: + wait = 1 + i * 2.0 + _log.exception('Failed to connect to voice... 
Retrying in %ss...', wait) + await self.disconnect(cleanup=False) + await asyncio.sleep(wait) + continue + else: + await self.disconnect() + raise + + _log.info('Voice connection complete.') + + if not self._runner: + self._runner = self.voice_client.loop.create_task(self._poll_voice_ws(reconnect), name='Voice websocket poller') + + async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None: + if not force and not self.is_connected(): + return + + try: + if self.ws: + await self.ws.close() + await self._voice_disconnect() + except Exception: + _log.debug('Ignoring exception disconnecting from voice', exc_info=True) + finally: + self.ip = MISSING + self.port = MISSING + self.state = ConnectionFlowState.disconnected + self._socket_reader.pause() + + # Flip the connected event to unlock any waiters + self._connected.set() + self._connected.clear() + + if cleanup: + self._socket_reader.stop() + self.voice_client.cleanup() + + if self.socket: + self.socket.close() + + async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None: + _log.debug('Soft disconnecting from voice') + # Stop the websocket reader because closing the websocket will trigger an unwanted reconnect + if self._runner: + self._runner.cancel() + self._runner = None + + try: + if self.ws: + await self.ws.close() + except Exception: + _log.debug('Ignoring exception soft disconnecting from voice', exc_info=True) + finally: + self.ip = MISSING + self.port = MISSING + self.state = with_state + self._socket_reader.pause() + + if self.socket: + self.socket.close() + + async def move_to(self, channel: Optional[abc.Snowflake]) -> None: + if channel is None: + await self.disconnect() + return + + await self.voice_client.channel.guild.change_voice_state(channel=channel) + self.state = ConnectionFlowState.set_guild_voice_state + + def wait(self, timeout: Optional[float] = None) -> bool: + return self._connected.wait(timeout) + + async 
def wait_async(self, timeout: Optional[float] = None) -> None: + await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout) + + def is_connected(self) -> bool: + return self.state is ConnectionFlowState.connected + + def send_packet(self, packet: bytes) -> None: + self.socket.sendall(packet) + + def add_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug('Registering socket listener callback %s', callback) + self._socket_reader.register(callback) + + def remove_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug('Unregistering socket listener callback %s', callback) + self._socket_reader.unregister(callback) + + async def _wait_for_state( + self, state: ConnectionFlowState, *other_states: ConnectionFlowState, timeout: Optional[float] = None + ) -> None: + states = (state, *other_states) + while True: + if self.state in states: + return + await sane_wait_for([self._state_event.wait()], timeout=timeout) + + async def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None: + channel = self.voice_client.channel + await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute) + + async def _voice_disconnect(self) -> None: + _log.info( + 'The voice handshake is being terminated for Channel ID %s (Guild ID %s)', + self.voice_client.channel.id, + self.voice_client.guild.id, + ) + self.state = ConnectionFlowState.disconnected + await self.voice_client.channel.guild.change_voice_state(channel=None) + self._expecting_disconnect = True + + async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket: + ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook) + self.state = ConnectionFlowState.websocket_connected + return ws + + async def _handshake_websocket(self) -> None: + while not self.ip: + await self.ws.poll_event() + self.state = ConnectionFlowState.got_ip_discovery + while self.ws.secret_key is 
None: + await self.ws.poll_event() + self.state = ConnectionFlowState.connected + + def _create_socket(self) -> None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + self._socket_reader.resume() + + async def _poll_voice_ws(self, reconnect: bool) -> None: + backoff = ExponentialBackoff() + while True: + try: + await self.ws.poll_event() + except asyncio.CancelledError: + return + except (ConnectionClosed, asyncio.TimeoutError) as exc: + if isinstance(exc, ConnectionClosed): + # The following close codes are undocumented so I will document them here. + # 1000 - normal closure (obviously) + # 4014 - we were externally disconnected (voice channel deleted, we were moved, etc) + # 4015 - voice server has crashed + if exc.code in (1000, 4015): + _log.info('Disconnecting from voice normally, close code %d.', exc.code) + await self.disconnect() + break + + if exc.code == 4014: + _log.info('Disconnected from voice by force... potentially reconnecting.') + successful = await self._potential_reconnect() + if not successful: + _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') + await self.disconnect() + break + else: + continue + + _log.debug('Not handling close code %s (%s)', exc.code, exc.reason or 'no reason') + + if not reconnect: + await self.disconnect() + raise + + retry = backoff.delay() + _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) + await asyncio.sleep(retry) + await self.disconnect(cleanup=False) + + try: + await self._connect( + reconnect=reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + ) + except asyncio.TimeoutError: + # at this point we've retried 5 times... let's continue the loop. + _log.warning('Could not connect to voice... 
Retrying...') + continue + + async def _potential_reconnect(self) -> bool: + try: + await self._wait_for_state( + ConnectionFlowState.got_voice_server_update, ConnectionFlowState.got_both_voice_updates, timeout=self.timeout + ) + except asyncio.TimeoutError: + return False + try: + self.ws = await self._connect_websocket(False) + await self._handshake_websocket() + except (ConnectionClosed, asyncio.TimeoutError): + return False + else: + return True