Browse Source

Implement VoiceProtocol lower level hooks.

This allows changing the connect flow and taking control of it without
relying on internal events or tricks.
pull/5849/head
Rapptz 5 years ago
parent
commit
0b93fa3a82
  1. 2
      discord/__init__.py
  2. 19
      discord/abc.py
  3. 6
      discord/client.py
  4. 2
      discord/ext/commands/context.py
  5. 2
      discord/guild.py
  6. 1
      discord/shard.py
  7. 14
      discord/state.py
  8. 287
      discord/voice_client.py
  9. 3
      docs/api.rst

2
discord/__init__.py

@ -54,7 +54,7 @@ from .mentions import AllowedMentions
from .shard import AutoShardedClient, ShardInfo from .shard import AutoShardedClient, ShardInfo
from .player import * from .player import *
from .webhook import * from .webhook import *
from .voice_client import VoiceClient from .voice_client import VoiceClient, VoiceProtocol
from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff
from .raw_models import * from .raw_models import *
from .team import * from .team import *

19
discord/abc.py

@ -36,7 +36,7 @@ from .permissions import PermissionOverwrite, Permissions
from .role import Role from .role import Role
from .invite import Invite from .invite import Invite
from .file import File from .file import File
from .voice_client import VoiceClient from .voice_client import VoiceClient, VoiceProtocol
from . import utils from . import utils
class _Undefined: class _Undefined:
@ -1053,7 +1053,6 @@ class Messageable(metaclass=abc.ABCMeta):
""" """
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
class Connectable(metaclass=abc.ABCMeta): class Connectable(metaclass=abc.ABCMeta):
"""An ABC that details the common operations on a channel that can """An ABC that details the common operations on a channel that can
connect to a voice server. connect to a voice server.
@ -1072,7 +1071,7 @@ class Connectable(metaclass=abc.ABCMeta):
def _get_voice_state_pair(self): def _get_voice_state_pair(self):
raise NotImplementedError raise NotImplementedError
async def connect(self, *, timeout=60.0, reconnect=True): async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient):
"""|coro| """|coro|
Connects to voice and creates a :class:`VoiceClient` to establish Connects to voice and creates a :class:`VoiceClient` to establish
@ -1086,6 +1085,9 @@ class Connectable(metaclass=abc.ABCMeta):
Whether the bot should automatically attempt Whether the bot should automatically attempt
a reconnect if a part of the handshake fails a reconnect if a part of the handshake fails
or the gateway goes down. or the gateway goes down.
cls: Type[:class:`VoiceProtocol`]
A type that subclasses :class:`~discord.VoiceProtocol` to connect with.
Defaults to :class:`~discord.VoiceClient`.
Raises Raises
------- -------
@ -1098,20 +1100,25 @@ class Connectable(metaclass=abc.ABCMeta):
Returns Returns
-------- --------
:class:`~discord.VoiceClient` :class:`~discord.VoiceProtocol`
A voice client that is fully connected to the voice server. A voice client that is fully connected to the voice server.
""" """
if not issubclass(cls, VoiceProtocol):
raise TypeError('Type must meet VoiceProtocol abstract base class.')
key_id, _ = self._get_voice_client_key() key_id, _ = self._get_voice_client_key()
state = self._state state = self._state
if state._get_voice_client(key_id): if state._get_voice_client(key_id):
raise ClientException('Already connected to a voice channel.') raise ClientException('Already connected to a voice channel.')
voice = VoiceClient(state=state, timeout=timeout, channel=self) client = state._get_client()
voice = cls(client, self)
state._add_voice_client(key_id, voice) state._add_voice_client(key_id, voice)
try: try:
await voice.connect(reconnect=reconnect) await voice.connect(timeout=timeout, reconnect=reconnect)
except asyncio.TimeoutError: except asyncio.TimeoutError:
try: try:
await voice.disconnect(force=True) await voice.disconnect(force=True)

6
discord/client.py

@ -238,6 +238,7 @@ class Client:
self._closed = False self._closed = False
self._ready = asyncio.Event() self._ready = asyncio.Event()
self._connection._get_websocket = self._get_websocket self._connection._get_websocket = self._get_websocket
self._connection._get_client = lambda: self
if VoiceClient.warn_nacl: if VoiceClient.warn_nacl:
VoiceClient.warn_nacl = False VoiceClient.warn_nacl = False
@ -299,7 +300,10 @@ class Client:
@property @property
def voice_clients(self): def voice_clients(self):
"""List[:class:`.VoiceClient`]: Represents a list of voice connections.""" """List[:class:`.VoiceProtocol`]: Represents a list of voice connections.
These are usually :class:`.VoiceClient` instances.
"""
return self._connection.voice_clients return self._connection.voice_clients
def is_ready(self): def is_ready(self):

2
discord/ext/commands/context.py

@ -238,7 +238,7 @@ class Context(discord.abc.Messageable):
@property @property
def voice_client(self): def voice_client(self):
r"""Optional[:class:`.VoiceClient`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild g = self.guild
return g.voice_client if g else None return g.voice_client if g else None

2
discord/guild.py

@ -377,7 +377,7 @@ class Guild(Hashable):
@property @property
def voice_client(self): def voice_client(self):
"""Optional[:class:`VoiceClient`]: Returns the :class:`VoiceClient` associated with this guild, if any.""" """Optional[:class:`VoiceProtocol`]: Returns the :class:`VoiceProtocol` associated with this guild, if any."""
return self._state._get_voice_client(self.id) return self._state._get_voice_client(self.id)
@property @property

1
discord/shard.py

@ -292,6 +292,7 @@ class AutoShardedClient(Client):
# the key is the shard_id # the key is the shard_id
self.__shards = {} self.__shards = {}
self._connection._get_websocket = self._get_websocket self._connection._get_websocket = self._get_websocket
self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue() self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None): def _get_websocket(self, guild_id=None, *, shard_id=None):

14
discord/state.py

@ -63,6 +63,12 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
async def logging_coroutine(coroutine, *, info):
try:
await coroutine
except Exception:
log.exception('Exception occurred during %s', info)
class ConnectionState: class ConnectionState:
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
self.loop = loop self.loop = loop
@ -939,9 +945,8 @@ class ConnectionState:
if int(data['user_id']) == self.user.id: if int(data['user_id']) == self.user.id:
voice = self._get_voice_client(guild.id) voice = self._get_voice_client(guild.id)
if voice is not None: if voice is not None:
ch = guild.get_channel(channel_id) coro = voice.on_voice_state_update(data)
if ch is not None: asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
voice.channel = ch
member, before, after = guild._update_voice_state(data, channel_id) member, before, after = guild._update_voice_state(data, channel_id)
if member is not None: if member is not None:
@ -962,7 +967,8 @@ class ConnectionState:
vc = self._get_voice_client(key_id) vc = self._get_voice_client(key_id)
if vc is not None: if vc is not None:
asyncio.ensure_future(vc._create_socket(key_id, data)) coro = vc.on_voice_server_update(data)
asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
def parse_typing_start(self, data): def parse_typing_start(self, data):
channel, guild = self._get_guild_channel(data) channel, guild = self._get_guild_channel(data)

287
discord/voice_client.py

@ -45,7 +45,7 @@ import logging
import struct import struct
import threading import threading
from . import opus from . import opus, utils
from .backoff import ExponentialBackoff from .backoff import ExponentialBackoff
from .gateway import * from .gateway import *
from .errors import ClientException, ConnectionClosed from .errors import ClientException, ConnectionClosed
@ -59,7 +59,110 @@ except ImportError:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class VoiceClient: class VoiceProtocol:
"""A class that represents the Discord voice protocol.
This is an abstract class. The library provides a concrete implementation
under :class:`VoiceClient`.
This class allows you to implement a protocol to allow for an external
method of sending voice, such as Lavalink_ or a native library implementation.
These classes are passed to :meth:`abc.Connectable.connect`.
.. _Lavalink: https://github.com/Frederikam/Lavalink
Parameters
------------
client: :class:`Client`
The client (or its subclasses) that started the connection request.
channel: :class:`abc.Connectable`
The voice channel that is being connected to.
"""
def __init__(self, client, channel):
self.client = client
self.channel = channel
async def on_voice_state_update(self, data):
"""|coro|
An abstract method that is called when the client's voice state
has changed. This corresponds to ``VOICE_STATE_UPDATE``.
Parameters
------------
data: :class:`dict`
The raw `voice state payload`_.
.. _voice state payload: https://discord.com/developers/docs/resources/voice#voice-state-object
"""
raise NotImplementedError
async def on_voice_server_update(self, data):
"""|coro|
An abstract method that is called when initially connecting to voice.
This corresponds to ``VOICE_SERVER_UPDATE``.
Parameters
------------
data: :class:`dict`
The raw `voice server update payload`__.
.. _VSU: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields
__ VSU_
"""
raise NotImplementedError
async def connect(self, *, timeout, reconnect):
"""|coro|
An abstract method called when the client initiates the connection request.
When a connection is requested initially, the library calls the following functions
in order:
- ``__init__``
Parameters
------------
timeout: :class:`float`
The timeout for the connection.
reconnect: :class:`bool`
Whether reconnection is expected.
"""
raise NotImplementedError
async def disconnect(self, *, force):
"""|coro|
An abstract method called when the client terminates the connection.
See :meth:`cleanup`.
Parameters
------------
force: :class:`bool`
Whether the disconnection was forced.
"""
raise NotImplementedError
def cleanup(self):
"""This method *must* be called to ensure proper clean-up during a disconnect.
It is advisable to call this from within :meth:`disconnect` when you are
completely done with the voice protocol instance.
This method removes it from the internal state cache that keeps track of
currently alive voice clients. Failure to clean-up will cause subsequent
connections to report that it's still connected.
"""
key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection. """Represents a Discord voice connection.
You do not create these, you typically get them from You do not create these, you typically get them from
@ -85,14 +188,13 @@ class VoiceClient:
loop: :class:`asyncio.AbstractEventLoop` loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on. The event loop that the voice client is running on.
""" """
def __init__(self, state, timeout, channel): def __init__(self, client, channel):
if not has_nacl: if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice") raise RuntimeError("PyNaCl library needed in order to use voice")
self.channel = channel super().__init__(client, channel)
self.main_ws = None state = client._connection
self.timeout = timeout self.token = None
self.ws = None
self.socket = None self.socket = None
self.loop = state.loop self.loop = state.loop
self._state = state self._state = state
@ -100,8 +202,8 @@ class VoiceClient:
self._connected = threading.Event() self._connected = threading.Event()
self._handshaking = False self._handshaking = False
self._handshake_check = asyncio.Lock() self._voice_state_complete = asyncio.Event()
self._handshake_complete = asyncio.Event() self._voice_server_complete = asyncio.Event()
self.mode = None self.mode = None
self._connections = 0 self._connections = 0
@ -138,48 +240,24 @@ class VoiceClient:
# connection related # connection related
async def start_handshake(self): async def on_voice_state_update(self, data):
log.info('Starting voice handshake...') self.session_id = data['session_id']
channel_id = data['channel_id']
guild_id, channel_id = self.channel._get_voice_state_pair()
state = self._state
self.main_ws = ws = state._get_websocket(guild_id)
self._connections += 1
# request joining
await ws.voice_state(guild_id, channel_id)
try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
await self.terminate_handshake(remove=True)
raise
log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip)
async def terminate_handshake(self, *, remove=False): if not self._handshaking:
guild_id, channel_id = self.channel._get_voice_state_pair() # If we're done handshaking then we just need to update ourselves
self._handshake_complete.clear() guild = self.guild
await self.main_ws.voice_state(guild_id, None, self_mute=True) self.channel = channel_id and guild and guild.get_channel(int(channel_id))
self._handshaking = False else:
self._voice_state_complete.set()
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', channel_id, guild_id) async def on_voice_server_update(self, data):
if remove: if self._voice_server_complete.is_set():
log.info('The voice client has been removed for Channel ID %s (Guild ID %s)', channel_id, guild_id) log.info('Ignoring extraneous voice server update.')
key_id, _ = self.channel._get_voice_client_key() return
self._state._remove_voice_client(key_id)
async def _create_socket(self, server_id, data):
async with self._handshake_check:
if self._handshaking:
log.info("Ignoring voice server update while handshake is in progress")
return
self._handshaking = True
self._connected.clear()
self.session_id = self.main_ws.session_id
self.server_id = server_id
self.token = data.get('token') self.token = data.get('token')
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint') endpoint = data.get('endpoint')
if endpoint is None or self.token is None: if endpoint is None or self.token is None:
@ -195,23 +273,77 @@ class VoiceClient:
# This gets set later # This gets set later
self.endpoint_ip = None self.endpoint_ip = None
if self.socket:
try:
self.socket.close()
except Exception:
pass
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False) self.socket.setblocking(False)
if self._handshake_complete.is_set(): if not self._handshaking:
# terminate the websocket and handle the reconnect loop if necessary. # If we're not handshaking then we need to terminate our previous connection in the websocket
self._handshake_complete.clear()
self._handshaking = False
await self.ws.close(4000) await self.ws.close(4000)
return return
self._handshake_complete.set() self._voice_server_complete.set()
async def voice_connect(self):
self._connections += 1
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self):
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
async def connect(self, *, reconnect, timeout):
log.info('Connecting to voice...')
self.timeout = timeout
try:
del self.secret_key
except AttributeError:
pass
for i in range(5):
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
# This has to be created before we start the flow.
futures = [
self._voice_state_complete.wait(),
self._voice_server_complete.wait(),
]
# Start the connection flow
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
await self.voice_connect()
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
await self.disconnect(force=True)
raise
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
try:
self.ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while not hasattr(self, 'secret_key'):
await self.ws.poll_event()
self._connected.set()
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
else:
raise
if self._runner is None:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
@property @property
def latency(self): def latency(self):
@ -234,35 +366,6 @@ class VoiceClient:
ws = self.ws ws = self.ws
return float("inf") if not ws else ws.average_latency return float("inf") if not ws else ws.average_latency
async def connect(self, *, reconnect=True, _tries=0, do_handshake=True):
log.info('Connecting to voice...')
try:
del self.secret_key
except AttributeError:
pass
if do_handshake:
await self.start_handshake()
try:
self.ws = await DiscordVoiceWebSocket.from_client(self)
self._handshaking = False
self._connected.clear()
while not hasattr(self, 'secret_key'):
await self.ws.poll_event()
self._connected.set()
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect and _tries < 5:
log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + _tries * 2.0)
await self.terminate_handshake()
await self.connect(reconnect=reconnect, _tries=_tries + 1)
else:
raise
if self._runner is None:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
async def poll_voice_ws(self, reconnect): async def poll_voice_ws(self, reconnect):
backoff = ExponentialBackoff() backoff = ExponentialBackoff()
while True: while True:
@ -287,9 +390,9 @@ class VoiceClient:
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear() self._connected.clear()
await asyncio.sleep(retry) await asyncio.sleep(retry)
await self.terminate_handshake() await self.voice_disconnect()
try: try:
await self.connect(reconnect=True) await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop. # at this point we've retried 5 times... let's continue the loop.
log.warning('Could not connect to voice... Retrying...') log.warning('Could not connect to voice... Retrying...')
@ -310,8 +413,9 @@ class VoiceClient:
if self.ws: if self.ws:
await self.ws.close() await self.ws.close()
await self.terminate_handshake(remove=True) await self.voice_disconnect()
finally: finally:
self.cleanup()
if self.socket: if self.socket:
self.socket.close() self.socket.close()
@ -325,8 +429,7 @@ class VoiceClient:
channel: :class:`abc.Snowflake` channel: :class:`abc.Snowflake`
The channel to move to. Must be a voice channel. The channel to move to. Must be a voice channel.
""" """
guild_id, _ = self.channel._get_voice_state_pair() await self.channel.guild.change_voice_state(channel=channel)
await self.main_ws.voice_state(guild_id, channel.id)
def is_connected(self): def is_connected(self):
"""Indicates if the voice client is connected to voice.""" """Indicates if the voice client is connected to voice."""

3
docs/api.rst

@ -54,6 +54,9 @@ Voice
.. autoclass:: VoiceClient() .. autoclass:: VoiceClient()
:members: :members:
.. autoclass:: VoiceProtocol
:members:
.. autoclass:: AudioSource .. autoclass:: AudioSource
:members: :members:

Loading…
Cancel
Save