Browse Source

Rewrite voice connection internals

pull/9585/head
Imayhaveborkedit 2 years ago
committed by GitHub
parent
commit
44284ae107
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      discord/abc.py
  2. 74
      discord/gateway.py
  3. 30
      discord/player.py
  4. 310
      discord/voice_client.py
  5. 596
      discord/voice_state.py

4
discord/abc.py

@ -1842,7 +1842,7 @@ class Connectable(Protocol):
async def connect(
self,
*,
timeout: float = 60.0,
timeout: float = 30.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = VoiceClient,
self_deaf: bool = False,
@ -1858,7 +1858,7 @@ class Connectable(Protocol):
Parameters
-----------
timeout: :class:`float`
The timeout in seconds to wait for the voice endpoint.
The timeout in seconds to wait the connection to complete.
reconnect: :class:`bool`
Whether the bot should automatically attempt
a reconnect if a part of the handshake fails

74
discord/gateway.py

@ -34,7 +34,7 @@ import threading
import traceback
import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
import aiohttp
import yarl
@ -59,7 +59,7 @@ if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
from .voice_state import VoiceConnectionState
class ReconnectWebSocket(Exception):
@ -797,7 +797,7 @@ class DiscordVoiceWebSocket:
if TYPE_CHECKING:
thread_id: int
_connection: VoiceClient
_connection: VoiceConnectionState
gateway: str
_max_heartbeat_timeout: float
@ -866,16 +866,21 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
@classmethod
async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
async def from_connection_state(
cls,
state: VoiceConnectionState,
*,
resume: bool = False,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
gateway = f'wss://{state.endpoint}/?v=4'
client = state.voice_client
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook)
ws.gateway = gateway
ws._connection = client
ws._connection = state
ws._max_heartbeat_timeout = 60.0
ws.thread_id = threading.get_ident()
@ -951,29 +956,49 @@ class DiscordVoiceWebSocket:
state.voice_port = data['port']
state.endpoint_ip = data['ip']
_log.debug('Connecting to voice socket')
await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))
state.ip, state.port = await self.discover_ip()
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ', '.join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug('selected the voice protocol for use (%s)', mode)
async def discover_ip(self) -> Tuple[str, int]:
state = self._connection
packet = bytearray(74)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 74)
_log.debug('received packet in initial_connection: %s', recv)
_log.debug('Sending ip discovery packet')
await self.loop.sock_sendall(state.socket, packet)
fut: asyncio.Future[bytes] = self.loop.create_future()
def get_ip_packet(data: bytes):
if data[1] == 0x02 and len(data) == 74:
self.loop.call_soon_threadsafe(fut.set_result, data)
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
state.add_socket_listener(get_ip_packet)
recv = await fut
_log.debug('Received ip discovery packet: %s', recv)
# the ip is ascii starting at the 8th byte and ending at the first null
ip_start = 8
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
ip = recv[ip_start:ip_end].decode('ascii')
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', ip, port)
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug('selected the voice protocol for use (%s)', mode)
return ip, port
@property
def latency(self) -> float:
@ -995,9 +1020,8 @@ class DiscordVoiceWebSocket:
self.secret_key = self._connection.secret_key = data['secret_key']
# Send a speak command with the "not speaking" state.
# This also tells Discord our SSRC value, which Discord requires
# before sending any voice data (and is the real reason why we
# call this here).
# This also tells Discord our SSRC value, which Discord requires before
# sending any voice data (and is the real reason why we call this here).
await self.speak(SpeakingState.none)
async def poll_event(self) -> None:
@ -1006,10 +1030,10 @@ class DiscordVoiceWebSocket:
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code: int = 1000) -> None:

30
discord/player.py

@ -703,7 +703,6 @@ class AudioPlayer(threading.Thread):
self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock()
if after is not None and not callable(after):
@ -714,7 +713,8 @@ class AudioPlayer(threading.Thread):
self._start = time.perf_counter()
# getattr lookup speed ups
play_audio = self.client.send_audio_packet
client = self.client
play_audio = client.send_audio_packet
self._speak(SpeakingState.voice)
while not self._end.is_set():
@ -725,22 +725,28 @@ class AudioPlayer(threading.Thread):
self._resumed.wait()
continue
# are we disconnected from voice?
if not self._connected.is_set():
# wait until we are connected
self._connected.wait()
# reset our internal data
self.loops = 0
self._start = time.perf_counter()
self.loops += 1
data = self.source.read()
if not data:
self.stop()
break
# are we disconnected from voice?
if not client.is_connected():
_log.debug('Not connected, waiting for %ss...', client.timeout)
# wait until we are connected, but not forever
connected = client.wait_until_connected(client.timeout)
if self._end.is_set() or not connected:
_log.debug('Aborting playback')
return
_log.debug('Reconnected, resuming playback')
self._speak(SpeakingState.voice)
# reset our internal data
self.loops = 0
self._start = time.perf_counter()
play_audio(data, encode=not self.source.is_opus())
self.loops += 1
next_time = self._start + self.DELAY * self.loops
delay = max(0, self.DELAY + (next_time - time.perf_counter()))
time.sleep(delay)
@ -792,7 +798,7 @@ class AudioPlayer(threading.Thread):
def is_paused(self) -> bool:
return not self._end.is_set() and not self._resumed.is_set()
def _set_source(self, source: AudioSource) -> None:
def set_source(self, source: AudioSource) -> None:
with self._lock:
self.pause(update_speaking=False)
self.source = source

310
discord/voice_client.py

@ -20,40 +20,24 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
Some documentation to refer to:
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
- We pull the session_id from VOICE_STATE_UPDATE.
- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE.
- Then we initiate the voice web socket (vWS) pointing to the endpoint.
- We send opcode 0 with the user_id, server_id, session_id and token using the vWS.
- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval.
- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE.
- Then we send our IP and port via vWS with opcode 1.
- When that's all done, we receive opcode 4 from the vWS.
- Finally we can transmit data to endpoint:port.
"""
from __future__ import annotations
import asyncio
import socket
import logging
import struct
import threading
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union
from . import opus, utils
from .backoff import ExponentialBackoff
from . import opus
from .gateway import *
from .errors import ClientException, ConnectionClosed
from .errors import ClientException
from .player import AudioPlayer, AudioSource
from .utils import MISSING
from .voice_state import VoiceConnectionState
if TYPE_CHECKING:
from .gateway import DiscordVoiceWebSocket
from .client import Client
from .guild import Guild
from .state import ConnectionState
@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol):
"""
channel: VocalGuildChannel
endpoint_ip: str
voice_port: int
ip: str
port: int
secret_key: List[int]
ssrc: int
def __init__(self, client: Client, channel: abc.Connectable) -> None:
if not has_nacl:
@ -239,29 +217,18 @@ class VoiceClient(VoiceProtocol):
super().__init__(client, channel)
state = client._connection
self.token: str = MISSING
self.server_id: int = MISSING
self.socket = MISSING
self.loop: asyncio.AbstractEventLoop = state.loop
self._state: ConnectionState = state
# this will be used in the AudioPlayer thread
self._connected: threading.Event = threading.Event()
self._handshaking: bool = False
self._potentially_reconnecting: bool = False
self._voice_state_complete: asyncio.Event = asyncio.Event()
self._voice_server_complete: asyncio.Event = asyncio.Event()
self.mode: str = MISSING
self._connections: int = 0
self.sequence: int = 0
self.timestamp: int = 0
self.timeout: float = 0
self._runner: asyncio.Task = MISSING
self._player: Optional[AudioPlayer] = None
self.encoder: Encoder = MISSING
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
self._connection: VoiceConnectionState = self.create_connection_state()
warn_nacl: bool = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
@ -280,6 +247,38 @@ class VoiceClient(VoiceProtocol):
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user # type: ignore
@property
def session_id(self) -> Optional[str]:
return self._connection.session_id
@property
def token(self) -> Optional[str]:
return self._connection.token
@property
def endpoint(self) -> Optional[str]:
return self._connection.endpoint
@property
def ssrc(self) -> int:
return self._connection.ssrc
@property
def mode(self) -> SupportedModes:
return self._connection.mode
@property
def secret_key(self) -> List[int]:
return self._connection.secret_key
@property
def ws(self) -> DiscordVoiceWebSocket:
return self._connection.ws
@property
def timeout(self) -> float:
return self._connection.timeout
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
@ -289,149 +288,23 @@ class VoiceClient(VoiceProtocol):
# connection related
def create_connection_state(self) -> VoiceConnectionState:
return VoiceConnectionState(self)
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id: str = data['session_id']
channel_id = data['channel_id']
if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves
# If we're potentially reconnecting due to a 4014, then we need to differentiate
# a channel move and an actual force disconnect
if channel_id is None:
# We're being disconnected so cleanup
await self.disconnect()
else:
self.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
else:
self._voice_state_complete.set()
await self._connection.voice_state_update(data)
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
_log.warning('Ignoring extraneous voice server update.')
return
self.token = data['token']
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
if endpoint is None or self.token is None:
_log.warning(
'Awaiting endpoint... This requires waiting. '
'If timeout occurred considering raising the timeout and reconnecting.'
)
return
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
# Just in case, strip it off since we're going to add it later
self.endpoint: str = self.endpoint[6:]
# This gets set later
self.endpoint_ip = MISSING
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
if not self._handshaking:
# If we're not handshaking then we need to terminate our previous connection in the websocket
await self.ws.close(4000)
return
self._voice_server_complete.set()
async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None:
await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute)
async def voice_disconnect(self) -> None:
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
def finish_handshake(self) -> None:
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
async def connect_websocket(self) -> DiscordVoiceWebSocket:
ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while ws.secret_key is None:
await ws.poll_event()
self._connected.set()
return ws
await self._connection.voice_server_update(data)
async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None:
_log.info('Connecting to voice...')
self.timeout = timeout
for i in range(5):
self.prepare_handshake()
# This has to be created before we start the flow.
futures = [
self._voice_state_complete.wait(),
self._voice_server_complete.wait(),
]
# Start the connection flow
await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute)
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
await self.disconnect(force=True)
raise
self.finish_handshake()
try:
self.ws = await self.connect_websocket()
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
_log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
else:
raise
if self._runner is MISSING:
self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early
self._connected.clear()
self.prepare_handshake()
self._potentially_reconnecting = True
try:
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
self._potentially_reconnecting = False
await self.disconnect(force=True)
return False
self.finish_handshake()
self._potentially_reconnecting = False
if self.ws:
_log.debug("Closing existing voice websocket")
await self.ws.close()
await self._connection.connect(
reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False
)
try:
self.ws = await self.connect_websocket()
except (ConnectionClosed, asyncio.TimeoutError):
return False
else:
return True
def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool:
self._connection.wait(timeout)
return self._connection.is_connected()
@property
def latency(self) -> float:
@ -442,7 +315,7 @@ class VoiceClient(VoiceProtocol):
.. versionadded:: 1.4
"""
ws = self.ws
ws = self._connection.ws
return float("inf") if not ws else ws.latency
@property
@ -451,72 +324,19 @@ class VoiceClient(VoiceProtocol):
.. versionadded:: 1.4
"""
ws = self.ws
ws = self._connection.ws
return float("inf") if not ws else ws.average_latency
async def poll_voice_ws(self, reconnect: bool) -> None:
backoff = ExponentialBackoff()
while True:
try:
await self.ws.poll_event()
except (ConnectionClosed, asyncio.TimeoutError) as exc:
if isinstance(exc, ConnectionClosed):
# The following close codes are undocumented so I will document them here.
# 1000 - normal closure (obviously)
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break
if exc.code == 4014:
_log.info('Disconnected from voice by force... potentially reconnecting.')
successful = await self.potential_reconnect()
if not successful:
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
await self.disconnect()
break
else:
continue
if not reconnect:
await self.disconnect()
raise
retry = backoff.delay()
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear()
await asyncio.sleep(retry)
await self.voice_disconnect()
try:
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
_log.warning('Could not connect to voice... Retrying...')
continue
async def disconnect(self, *, force: bool = False) -> None:
"""|coro|
Disconnects this voice client from voice.
"""
if not force and not self.is_connected():
return
self.stop()
self._connected.clear()
try:
if self.ws:
await self.ws.close()
await self.voice_disconnect()
finally:
self.cleanup()
if self.socket:
self.socket.close()
await self._connection.disconnect(force=force)
self.cleanup()
async def move_to(self, channel: Optional[abc.Snowflake]) -> None:
async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None:
"""|coro|
Moves you to a different voice channel.
@ -525,12 +345,22 @@ class VoiceClient(VoiceProtocol):
-----------
channel: Optional[:class:`abc.Snowflake`]
The channel to move to. Must be a voice channel.
timeout: Optional[:class:`float`]
How long to wait for the move to complete.
.. versionadded:: 2.4
Raises
-------
asyncio.TimeoutError
The move did not complete in time, but may still be ongoing.
"""
await self.channel.guild.change_voice_state(channel=channel)
await self._connection.move_to(channel)
await self._connection.wait_async(timeout)
def is_connected(self) -> bool:
"""Indicates if the voice client is connected to voice."""
return self._connected.is_set()
return self._connection.is_connected()
# audio related
@ -703,7 +533,7 @@ class VoiceClient(VoiceProtocol):
if self._player is None:
raise ValueError('Not playing anything.')
self._player._set_source(value)
self._player.set_source(value)
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
"""Sends an audio packet composed of the data.
@ -732,8 +562,8 @@ class VoiceClient(VoiceProtocol):
encoded_data = data
packet = self._get_voice_packet(encoded_data)
try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError:
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
self._connection.send_packet(packet)
except OSError:
_log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)

596
discord/voice_state.py

@ -0,0 +1,596 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
Some documentation to refer to:
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
- We pull the session_id from VOICE_STATE_UPDATE.
- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE.
- Then we initiate the voice web socket (vWS) pointing to the endpoint.
- We send opcode 0 with the user_id, server_id, session_id and token using the vWS.
- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval.
- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE.
- Then we send our IP and port via vWS with opcode 1.
- When that's all done, we receive opcode 4 from the vWS.
- Finally we can transmit data to endpoint:port.
"""
from __future__ import annotations
import select
import socket
import asyncio
import logging
import threading
import async_timeout
from typing import TYPE_CHECKING, Optional, Dict, List, Callable, Coroutine, Any, Tuple
from .enums import Enum
from .utils import MISSING, sane_wait_for
from .errors import ConnectionClosed
from .backoff import ExponentialBackoff
from .gateway import DiscordVoiceWebSocket
if TYPE_CHECKING:
from . import abc
from .guild import Guild
from .user import ClientUser
from .member import VoiceState
from .voice_client import VoiceClient
from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceServerUpdate as VoiceServerUpdatePayload,
SupportedModes,
)
WebsocketHook = Optional[Callable[['VoiceConnectionState', Dict[str, Any]], Coroutine[Any, Any, Any]]]
SocketReaderCallback = Callable[[bytes], Any]
__all__ = ('VoiceConnectionState',)
_log = logging.getLogger(__name__)
class SocketReader(threading.Thread):
def __init__(self, state: VoiceConnectionState) -> None:
super().__init__(daemon=True, name=f'voice-socket-reader:{id(self):#x}')
self.state: VoiceConnectionState = state
self._callbacks: List[SocketReaderCallback] = []
self._running = threading.Event()
self._end = threading.Event()
# If we have paused reading due to having no callbacks
self._idle_paused: bool = True
def register(self, callback: SocketReaderCallback) -> None:
self._callbacks.append(callback)
if self._idle_paused:
self._idle_paused = False
self._running.set()
def unregister(self, callback: SocketReaderCallback) -> None:
try:
self._callbacks.remove(callback)
except ValueError:
pass
else:
if not self._callbacks and self._running.is_set():
# If running is not set, we are either explicitly paused and
# should be explicitly resumed, or we are already idle paused
self._idle_paused = True
self._running.clear()
def pause(self) -> None:
self._idle_paused = False
self._running.clear()
def resume(self, *, force: bool = False) -> None:
if self._running.is_set():
return
# Don't resume if there are no callbacks registered
if not force and not self._callbacks:
# We tried to resume but there was nothing to do, so resume when ready
self._idle_paused = True
return
self._idle_paused = False
self._running.set()
def stop(self) -> None:
self._end.set()
self._running.set()
def run(self) -> None:
self._end.clear()
self._running.set()
try:
self._do_run()
except Exception:
_log.exception('Error in %s', self)
finally:
self.stop()
self._running.clear()
self._callbacks.clear()
def _do_run(self) -> None:
while not self._end.is_set():
if not self._running.is_set():
self._running.wait()
continue
# Since this socket is a non blocking socket, select has to be used to wait on it for reading.
try:
readable, _, _ = select.select([self.state.socket], [], [], 30)
except (ValueError, TypeError):
# The socket is either closed or doesn't exist at the moment
continue
if not readable:
continue
try:
data = self.state.socket.recv(2048)
except OSError:
_log.debug('Error reading from socket in %s, this should be safe to ignore', self, exc_info=True)
else:
for cb in self._callbacks:
try:
cb(data)
except Exception:
_log.exception('Error calling %s in %s', cb, self)
class ConnectionFlowState(Enum):
"""Enum representing voice connection flow state."""
# fmt: off
disconnected = 0
set_guild_voice_state = 1
got_voice_state_update = 2
got_voice_server_update = 3
got_both_voice_updates = 4
websocket_connected = 5
got_websocket_ready = 6
got_ip_discovery = 7
connected = 8
# fmt: on
class VoiceConnectionState:
"""Represents the internal state of a voice connection."""
def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] = None) -> None:
self.voice_client = voice_client
self.hook = hook
self.timeout: float = 30.0
self.reconnect: bool = True
self.self_deaf: bool = False
self.self_mute: bool = False
self.token: Optional[str] = None
self.session_id: Optional[str] = None
self.endpoint: Optional[str] = None
self.endpoint_ip: Optional[str] = None
self.server_id: Optional[int] = None
self.ip: Optional[str] = None
self.port: Optional[int] = None
self.voice_port: Optional[int] = None
self.secret_key: List[int] = MISSING
self.ssrc: int = MISSING
self.mode: SupportedModes = MISSING
self.socket: socket.socket = MISSING
self.ws: DiscordVoiceWebSocket = MISSING
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
self._expecting_disconnect: bool = False
self._connected = threading.Event()
self._state_event = asyncio.Event()
self._runner: Optional[asyncio.Task] = None
self._connector: Optional[asyncio.Task] = None
self._socket_reader = SocketReader(self)
self._socket_reader.start()
@property
def state(self) -> ConnectionFlowState:
return self._state
@state.setter
def state(self, state: ConnectionFlowState) -> None:
if state is not self._state:
_log.debug('Connection state changed to %s', state.name)
self._state = state
self._state_event.set()
self._state_event.clear()
if state is ConnectionFlowState.connected:
self._connected.set()
else:
self._connected.clear()
@property
def guild(self) -> Guild:
return self.voice_client.guild
@property
def user(self) -> ClientUser:
return self.voice_client.user
@property
def supported_modes(self) -> Tuple[SupportedModes, ...]:
return self.voice_client.supported_modes
@property
def self_voice_state(self) -> Optional[VoiceState]:
return self.guild.me.voice
async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
channel_id = data['channel_id']
if channel_id is None:
# If we know we're going to get a voice_state_update where we have no channel due to
# being in the reconnect flow, we ignore it. Otherwise, it probably wasn't from us.
if self._expecting_disconnect:
self._expecting_disconnect = False
else:
_log.debug('We were externally disconnected from voice.')
await self.disconnect()
return
self.session_id = data['session_id']
# we got the event while connecting
if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_server_update):
if self.state is ConnectionFlowState.set_guild_voice_state:
self.state = ConnectionFlowState.got_voice_state_update
else:
self.state = ConnectionFlowState.got_both_voice_updates
return
if self.state is ConnectionFlowState.connected:
self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
elif self.state is not ConnectionFlowState.disconnected:
if channel_id != self.voice_client.channel.id:
# For some unfortunate reason we were moved during the connection flow
_log.info('Handling channel move while connecting...')
self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update)
await self.connect(
reconnect=self.reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=False,
wait=False,
)
else:
_log.debug('Ignoring unexpected voice_state_update event')
async def voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
self.token = data['token']
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
if self.token is None or endpoint is None:
_log.warning(
'Awaiting endpoint... This requires waiting. '
'If timeout occurred considering raising the timeout and reconnecting.'
)
return
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
# Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:]
# we got the event while connecting
if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update):
# This gets set after READY is received
self.endpoint_ip = MISSING
self._create_socket()
if self.state is ConnectionFlowState.set_guild_voice_state:
self.state = ConnectionFlowState.got_voice_server_update
else:
self.state = ConnectionFlowState.got_both_voice_updates
elif self.state is ConnectionFlowState.connected:
_log.debug('Voice server update, closing old voice websocket')
await self.ws.close(4014)
self.state = ConnectionFlowState.got_voice_server_update
elif self.state is not ConnectionFlowState.disconnected:
_log.debug('Unexpected server update event, attempting to handle')
await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update)
await self.connect(
reconnect=self.reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=False,
wait=False,
)
self._create_socket()
async def connect(
self, *, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool, wait: bool = True
) -> None:
if self._connector:
self._connector.cancel()
self._connector = None
if self._runner:
self._runner.cancel()
self._runner = None
self.timeout = timeout
self.reconnect = reconnect
self._connector = self.voice_client.loop.create_task(
self._wrap_connect(reconnect, timeout, self_deaf, self_mute, resume), name='Voice connector'
)
if wait:
await self._connector
async def _wrap_connect(self, *args: Any) -> None:
try:
await self._connect(*args)
except asyncio.CancelledError:
_log.debug('Cancelling voice connection')
await self.soft_disconnect()
raise
except asyncio.TimeoutError:
_log.info('Timed out connecting to voice')
await self.disconnect()
raise
except Exception:
_log.exception('Error connecting to voice... disconnecting')
await self.disconnect()
raise
async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None:
_log.info('Connecting to voice...')
async with async_timeout.timeout(timeout):
for i in range(5):
_log.info('Starting voice handshake... (connection attempt %d)', i + 1)
await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute)
# Setting this unnecessarily will break reconnecting
if self.state is ConnectionFlowState.disconnected:
self.state = ConnectionFlowState.set_guild_voice_state
await self._wait_for_state(ConnectionFlowState.got_both_voice_updates)
_log.info('Voice handshake complete. Endpoint found: %s', self.endpoint)
try:
self.ws = await self._connect_websocket(resume)
await self._handshake_websocket()
break
except ConnectionClosed:
if reconnect:
wait = 1 + i * 2.0
_log.exception('Failed to connect to voice... Retrying in %ss...', wait)
await self.disconnect(cleanup=False)
await asyncio.sleep(wait)
continue
else:
await self.disconnect()
raise
_log.info('Voice connection complete.')
if not self._runner:
self._runner = self.voice_client.loop.create_task(self._poll_voice_ws(reconnect), name='Voice websocket poller')
async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None:
if not force and not self.is_connected():
return
try:
if self.ws:
await self.ws.close()
await self._voice_disconnect()
except Exception:
_log.debug('Ignoring exception disconnecting from voice', exc_info=True)
finally:
self.ip = MISSING
self.port = MISSING
self.state = ConnectionFlowState.disconnected
self._socket_reader.pause()
# Flip the connected event to unlock any waiters
self._connected.set()
self._connected.clear()
if cleanup:
self._socket_reader.stop()
self.voice_client.cleanup()
if self.socket:
self.socket.close()
async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None:
_log.debug('Soft disconnecting from voice')
# Stop the websocket reader because closing the websocket will trigger an unwanted reconnect
if self._runner:
self._runner.cancel()
self._runner = None
try:
if self.ws:
await self.ws.close()
except Exception:
_log.debug('Ignoring exception soft disconnecting from voice', exc_info=True)
finally:
self.ip = MISSING
self.port = MISSING
self.state = with_state
self._socket_reader.pause()
if self.socket:
self.socket.close()
async def move_to(self, channel: Optional[abc.Snowflake]) -> None:
if channel is None:
await self.disconnect()
return
await self.voice_client.channel.guild.change_voice_state(channel=channel)
self.state = ConnectionFlowState.set_guild_voice_state
def wait(self, timeout: Optional[float] = None) -> bool:
return self._connected.wait(timeout)
async def wait_async(self, timeout: Optional[float] = None) -> None:
await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout)
def is_connected(self) -> bool:
return self.state is ConnectionFlowState.connected
def send_packet(self, packet: bytes) -> None:
self.socket.sendall(packet)
def add_socket_listener(self, callback: SocketReaderCallback) -> None:
_log.debug('Registering socket listener callback %s', callback)
self._socket_reader.register(callback)
def remove_socket_listener(self, callback: SocketReaderCallback) -> None:
_log.debug('Unregistering socket listener callback %s', callback)
self._socket_reader.unregister(callback)
async def _wait_for_state(
self, state: ConnectionFlowState, *other_states: ConnectionFlowState, timeout: Optional[float] = None
) -> None:
states = (state, *other_states)
while True:
if self.state in states:
return
await sane_wait_for([self._state_event.wait()], timeout=timeout)
async def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None:
channel = self.voice_client.channel
await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute)
async def _voice_disconnect(self) -> None:
_log.info(
'The voice handshake is being terminated for Channel ID %s (Guild ID %s)',
self.voice_client.channel.id,
self.voice_client.guild.id,
)
self.state = ConnectionFlowState.disconnected
await self.voice_client.channel.guild.change_voice_state(channel=None)
self._expecting_disconnect = True
async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket:
ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook)
self.state = ConnectionFlowState.websocket_connected
return ws
async def _handshake_websocket(self) -> None:
while not self.ip:
await self.ws.poll_event()
self.state = ConnectionFlowState.got_ip_discovery
while self.ws.secret_key is None:
await self.ws.poll_event()
self.state = ConnectionFlowState.connected
def _create_socket(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
self._socket_reader.resume()
async def _poll_voice_ws(self, reconnect: bool) -> None:
backoff = ExponentialBackoff()
while True:
try:
await self.ws.poll_event()
except asyncio.CancelledError:
return
except (ConnectionClosed, asyncio.TimeoutError) as exc:
if isinstance(exc, ConnectionClosed):
# The following close codes are undocumented so I will document them here.
# 1000 - normal closure (obviously)
# 4014 - we were externally disconnected (voice channel deleted, we were moved, etc)
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break
if exc.code == 4014:
_log.info('Disconnected from voice by force... potentially reconnecting.')
successful = await self._potential_reconnect()
if not successful:
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
await self.disconnect()
break
else:
continue
_log.debug('Not handling close code %s (%s)', exc.code, exc.reason or 'no reason')
if not reconnect:
await self.disconnect()
raise
retry = backoff.delay()
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
await asyncio.sleep(retry)
await self.disconnect(cleanup=False)
try:
await self._connect(
reconnect=reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=False,
)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
_log.warning('Could not connect to voice... Retrying...')
continue
async def _potential_reconnect(self) -> bool:
try:
await self._wait_for_state(
ConnectionFlowState.got_voice_server_update, ConnectionFlowState.got_both_voice_updates, timeout=self.timeout
)
except asyncio.TimeoutError:
return False
try:
self.ws = await self._connect_websocket(False)
await self._handshake_websocket()
except (ConnectionClosed, asyncio.TimeoutError):
return False
else:
return True
Loading…
Cancel
Save