From 0b93fa3a82e8b1b1d9f637e7be0333efd0a232b2 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 10 Aug 2020 06:28:36 -0400 Subject: [PATCH] Implement VoiceProtocol lower level hooks. This allows changing the connect flow and taking control of it without relying on internal events or tricks. --- discord/__init__.py | 2 +- discord/abc.py | 19 ++- discord/client.py | 6 +- discord/ext/commands/context.py | 2 +- discord/guild.py | 2 +- discord/shard.py | 1 + discord/state.py | 14 +- discord/voice_client.py | 287 ++++++++++++++++++++++---------- docs/api.rst | 3 + 9 files changed, 230 insertions(+), 106 deletions(-) diff --git a/discord/__init__.py b/discord/__init__.py index 78b25e31d..c6b215939 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -54,7 +54,7 @@ from .mentions import AllowedMentions from .shard import AutoShardedClient, ShardInfo from .player import * from .webhook import * -from .voice_client import VoiceClient +from .voice_client import VoiceClient, VoiceProtocol from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff from .raw_models import * from .team import * diff --git a/discord/abc.py b/discord/abc.py index 4024334d1..75624e18f 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -36,7 +36,7 @@ from .permissions import PermissionOverwrite, Permissions from .role import Role from .invite import Invite from .file import File -from .voice_client import VoiceClient +from .voice_client import VoiceClient, VoiceProtocol from . import utils class _Undefined: @@ -1053,7 +1053,6 @@ class Messageable(metaclass=abc.ABCMeta): """ return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) - class Connectable(metaclass=abc.ABCMeta): """An ABC that details the common operations on a channel that can connect to a voice server. @@ -1072,7 +1071,7 @@ class Connectable(metaclass=abc.ABCMeta): def _get_voice_state_pair(self): raise NotImplementedError - async def connect(self, *, timeout=60.0, reconnect=True): + async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient): """|coro| Connects to voice and creates a :class:`VoiceClient` to establish @@ -1086,6 +1085,9 @@ class Connectable(metaclass=abc.ABCMeta): Whether the bot should automatically attempt a reconnect if a part of the handshake fails or the gateway goes down. + cls: Type[:class:`VoiceProtocol`] + A type that subclasses :class:`~discord.VoiceProtocol` to connect with. + Defaults to :class:`~discord.VoiceClient`. Raises ------- @@ -1098,20 +1100,25 @@ class Connectable(metaclass=abc.ABCMeta): Returns -------- - :class:`~discord.VoiceClient` + :class:`~discord.VoiceProtocol` A voice client that is fully connected to the voice server. """ + + if not issubclass(cls, VoiceProtocol): + raise TypeError('Type must meet VoiceProtocol abstract base class.') + key_id, _ = self._get_voice_client_key() state = self._state if state._get_voice_client(key_id): raise ClientException('Already connected to a voice channel.') - voice = VoiceClient(state=state, timeout=timeout, channel=self) + client = state._get_client() + voice = cls(client, self) state._add_voice_client(key_id, voice) try: - await voice.connect(reconnect=reconnect) + await voice.connect(timeout=timeout, reconnect=reconnect) except asyncio.TimeoutError: try: await voice.disconnect(force=True) diff --git a/discord/client.py b/discord/client.py index 9bc5dd120..407fd47f9 100644 --- a/discord/client.py +++ b/discord/client.py @@ -238,6 +238,7 @@ class Client: self._closed = False self._ready = asyncio.Event() self._connection._get_websocket = self._get_websocket + self._connection._get_client = lambda: self if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False @@ -299,7 +300,10 @@ class Client: @property def voice_clients(self): - """List[:class:`.VoiceClient`]: Represents a list of voice connections.""" + """List[:class:`.VoiceProtocol`]: Represents a list of voice connections. + + These are usually :class:`.VoiceClient` instances. + """ return self._connection.voice_clients def is_ready(self): diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 8b8cf4bcb..3cf851c68 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -238,7 +238,7 @@ class Context(discord.abc.Messageable): @property def voice_client(self): - r"""Optional[:class:`.VoiceClient`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" + r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" g = self.guild return g.voice_client if g else None diff --git a/discord/guild.py b/discord/guild.py index 4c6013a32..0bf94a280 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -377,7 +377,7 @@ class Guild(Hashable): @property def voice_client(self): - """Optional[:class:`VoiceClient`]: Returns the :class:`VoiceClient` associated with this guild, if any.""" + """Optional[:class:`VoiceProtocol`]: Returns the :class:`VoiceProtocol` associated with this guild, if any.""" return self._state._get_voice_client(self.id) @property diff --git a/discord/shard.py b/discord/shard.py index f63206786..ef29d5907 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -292,6 +292,7 @@ class AutoShardedClient(Client): # the key is the shard_id self.__shards = {} self._connection._get_websocket = self._get_websocket + self._connection._get_client = lambda: self self.__queue = asyncio.PriorityQueue() def _get_websocket(self, guild_id=None, *, shard_id=None): diff --git a/discord/state.py b/discord/state.py index f0e93d352..fc297d035 100644 --- a/discord/state.py +++ b/discord/state.py @@ -63,6 +63,12 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate')) log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) +async def logging_coroutine(coroutine, *, info): + try: + await coroutine + except Exception: + log.exception('Exception occurred during %s', info) + class ConnectionState: def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): self.loop = loop @@ -939,9 +945,8 @@ class ConnectionState: if int(data['user_id']) == self.user.id: voice = self._get_voice_client(guild.id) if voice is not None: - ch = guild.get_channel(channel_id) - if ch is not None: - voice.channel = ch + coro = voice.on_voice_state_update(data) + asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice state update handler')) member, before, after = guild._update_voice_state(data, channel_id) if member is not None: @@ -962,7 +967,8 @@ class ConnectionState: vc = self._get_voice_client(key_id) if vc is not None: - asyncio.ensure_future(vc._create_socket(key_id, data)) + coro = vc.on_voice_server_update(data) + asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice server update handler')) def parse_typing_start(self, data): channel, guild = self._get_guild_channel(data) diff --git a/discord/voice_client.py b/discord/voice_client.py index ab9a6406c..a1a7109a4 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -45,7 +45,7 @@ import logging import struct import threading -from . import opus +from . import opus, utils from .backoff import ExponentialBackoff from .gateway import * from .errors import ClientException, ConnectionClosed @@ -59,7 +59,110 @@ except ImportError: log = logging.getLogger(__name__) -class VoiceClient: +class VoiceProtocol: + """A class that represents the Discord voice protocol. + + This is an abstract class. The library provides a concrete implementation + under :class:`VoiceClient`. + + This class allows you to implement a protocol to allow for an external + method of sending voice, such as Lavalink_ or a native library implementation. + + These classes are passed to :meth:`abc.Connectable.connect`. + + .. _Lavalink: https://github.com/Frederikam/Lavalink + + Parameters + ------------ + client: :class:`Client` + The client (or its subclasses) that started the connection request. + channel: :class:`abc.Connectable` + The voice channel that is being connected to. + """ + + def __init__(self, client, channel): + self.client = client + self.channel = channel + + async def on_voice_state_update(self, data): + """|coro| + + An abstract method that is called when the client's voice state + has changed. This corresponds to ``VOICE_STATE_UPDATE``. + + Parameters + ------------ + data: :class:`dict` + The raw `voice state payload`_. + + .. _voice state payload: https://discord.com/developers/docs/resources/voice#voice-state-object + """ + raise NotImplementedError + + async def on_voice_server_update(self, data): + """|coro| + + An abstract method that is called when initially connecting to voice. + This corresponds to ``VOICE_SERVER_UPDATE``. + + Parameters + ------------ + data: :class:`dict` + The raw `voice server update payload`__. + + .. _VSU: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields + + __ VSU_ + """ + raise NotImplementedError + + async def connect(self, *, timeout, reconnect): + """|coro| + + An abstract method called when the client initiates the connection request. + + When a connection is requested initially, the library calls the following functions + in order: + + - ``__init__`` + + Parameters + ------------ + timeout: :class:`float` + The timeout for the connection. + reconnect: :class:`bool` + Whether reconnection is expected. + """ + raise NotImplementedError + + async def disconnect(self, *, force): + """|coro| + + An abstract method called when the client terminates the connection. + + See :meth:`cleanup`. + + Parameters + ------------ + force: :class:`bool` + Whether the disconnection was forced. + """ + raise NotImplementedError + + def cleanup(self): + """This method *must* be called to ensure proper clean-up during a disconnect. + + It is advisable to call this from within :meth:`disconnect` when you are + completely done with the voice protocol instance. + + This method removes it from the internal state cache that keeps track of + currently alive voice clients. Failure to clean-up will cause subsequent + connections to report that it's still connected. + """ + key_id, _ = self.channel._get_voice_client_key() + self.client._connection._remove_voice_client(key_id) + +class VoiceClient(VoiceProtocol): """Represents a Discord voice connection. You do not create these, you typically get them from @@ -85,14 +188,13 @@ class VoiceClient: loop: :class:`asyncio.AbstractEventLoop` The event loop that the voice client is running on. """ - def __init__(self, state, timeout, channel): + def __init__(self, client, channel): if not has_nacl: raise RuntimeError("PyNaCl library needed in order to use voice") - self.channel = channel - self.main_ws = None - self.timeout = timeout - self.ws = None + super().__init__(client, channel) + state = client._connection + self.token = None self.socket = None self.loop = state.loop self._state = state @@ -100,8 +202,8 @@ class VoiceClient: self._connected = threading.Event() self._handshaking = False - self._handshake_check = asyncio.Lock() - self._handshake_complete = asyncio.Event() + self._voice_state_complete = asyncio.Event() + self._voice_server_complete = asyncio.Event() self.mode = None self._connections = 0 @@ -138,48 +240,24 @@ class VoiceClient: # connection related - async def start_handshake(self): - log.info('Starting voice handshake...') - - guild_id, channel_id = self.channel._get_voice_state_pair() - state = self._state - self.main_ws = ws = state._get_websocket(guild_id) - self._connections += 1 - - # request joining - await ws.voice_state(guild_id, channel_id) - - try: - await asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout) - except asyncio.TimeoutError: - await self.terminate_handshake(remove=True) - raise - - log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip) + async def on_voice_state_update(self, data): + self.session_id = data['session_id'] + channel_id = data['channel_id'] - async def terminate_handshake(self, *, remove=False): - guild_id, channel_id = self.channel._get_voice_state_pair() - self._handshake_complete.clear() - await self.main_ws.voice_state(guild_id, None, self_mute=True) - self._handshaking = False + if not self._handshaking: + # If we're done handshaking then we just need to update ourselves + guild = self.guild + self.channel = channel_id and guild and guild.get_channel(int(channel_id)) + else: + self._voice_state_complete.set() - log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', channel_id, guild_id) - if remove: - log.info('The voice client has been removed for Channel ID %s (Guild ID %s)', channel_id, guild_id) - key_id, _ = self.channel._get_voice_client_key() - self._state._remove_voice_client(key_id) - - async def _create_socket(self, server_id, data): - async with self._handshake_check: - if self._handshaking: - log.info("Ignoring voice server update while handshake is in progress") - return - self._handshaking = True + async def on_voice_server_update(self, data): + if self._voice_server_complete.is_set(): + log.info('Ignoring extraneous voice server update.') + return - self._connected.clear() - self.session_id = self.main_ws.session_id - self.server_id = server_id self.token = data.get('token') + self.server_id = int(data['guild_id']) endpoint = data.get('endpoint') if endpoint is None or self.token is None: @@ -195,23 +273,77 @@ class VoiceClient: # This gets set later self.endpoint_ip = None - if self.socket: - try: - self.socket.close() - except Exception: - pass - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.setblocking(False) - if self._handshake_complete.is_set(): - # terminate the websocket and handle the reconnect loop if necessary. - self._handshake_complete.clear() - self._handshaking = False + if not self._handshaking: + # If we're not handshaking then we need to terminate our previous connection in the websocket await self.ws.close(4000) return - self._handshake_complete.set() + self._voice_server_complete.set() + + async def voice_connect(self): + self._connections += 1 + await self.channel.guild.change_voice_state(channel=self.channel) + + async def voice_disconnect(self): + log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) + await self.channel.guild.change_voice_state(channel=None) + + async def connect(self, *, reconnect, timeout): + log.info('Connecting to voice...') + self.timeout = timeout + try: + del self.secret_key + except AttributeError: + pass + + + for i in range(5): + self._voice_state_complete.clear() + self._voice_server_complete.clear() + self._handshaking = True + + # This has to be created before we start the flow. + futures = [ + self._voice_state_complete.wait(), + self._voice_server_complete.wait(), + ] + + # Start the connection flow + log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) + await self.voice_connect() + + try: + await utils.sane_wait_for(futures, timeout=timeout) + except asyncio.TimeoutError: + await self.disconnect(force=True) + raise + + log.info('Voice handshake complete. Endpoint found %s', self.endpoint) + self._handshaking = False + self._voice_server_complete.clear() + self._voice_state_complete.clear() + + try: + self.ws = await DiscordVoiceWebSocket.from_client(self) + self._connected.clear() + while not hasattr(self, 'secret_key'): + await self.ws.poll_event() + self._connected.set() + break + except (ConnectionClosed, asyncio.TimeoutError): + if reconnect: + log.exception('Failed to connect to voice... Retrying...') + await asyncio.sleep(1 + i * 2.0) + await self.voice_disconnect() + continue + else: + raise + + if self._runner is None: + self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) @property def latency(self): @@ -234,35 +366,6 @@ class VoiceClient: ws = self.ws return float("inf") if not ws else ws.average_latency - async def connect(self, *, reconnect=True, _tries=0, do_handshake=True): - log.info('Connecting to voice...') - try: - del self.secret_key - except AttributeError: - pass - - if do_handshake: - await self.start_handshake() - - try: - self.ws = await DiscordVoiceWebSocket.from_client(self) - self._handshaking = False - self._connected.clear() - while not hasattr(self, 'secret_key'): - await self.ws.poll_event() - self._connected.set() - except (ConnectionClosed, asyncio.TimeoutError): - if reconnect and _tries < 5: - log.exception('Failed to connect to voice... Retrying...') - await asyncio.sleep(1 + _tries * 2.0) - await self.terminate_handshake() - await self.connect(reconnect=reconnect, _tries=_tries + 1) - else: - raise - - if self._runner is None: - self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) - async def poll_voice_ws(self, reconnect): backoff = ExponentialBackoff() while True: @@ -287,9 +390,9 @@ class VoiceClient: log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) self._connected.clear() await asyncio.sleep(retry) - await self.terminate_handshake() + await self.voice_disconnect() try: - await self.connect(reconnect=True) + await self.connect(reconnect=True, timeout=self.timeout) except asyncio.TimeoutError: # at this point we've retried 5 times... let's continue the loop. log.warning('Could not connect to voice... Retrying...') @@ -310,8 +413,9 @@ class VoiceClient: if self.ws: await self.ws.close() - await self.terminate_handshake(remove=True) + await self.voice_disconnect() finally: + self.cleanup() if self.socket: self.socket.close() @@ -325,8 +429,7 @@ class VoiceClient: channel: :class:`abc.Snowflake` The channel to move to. Must be a voice channel. """ - guild_id, _ = self.channel._get_voice_state_pair() - await self.main_ws.voice_state(guild_id, channel.id) + await self.channel.guild.change_voice_state(channel=channel) def is_connected(self): """Indicates if the voice client is connected to voice.""" diff --git a/docs/api.rst b/docs/api.rst index 6b843bd1f..d4af5ff19 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -54,6 +54,9 @@ Voice .. autoclass:: VoiceClient() :members: +.. autoclass:: VoiceProtocol + :members: + .. autoclass:: AudioSource :members: