Browse Source

Working multi-server voice support.

pull/198/head
Rapptz 9 years ago
parent
commit
d9c780b8a8
  1. 73
      discord/client.py
  2. 61
      discord/gateway.py
  3. 21
      discord/state.py
  4. 7
      discord/voice_client.py

73
discord/client.py

@ -90,10 +90,10 @@ class Client:
----------- -----------
user : Optional[:class:`User`] user : Optional[:class:`User`]
Represents the connected client. None if not logged in. Represents the connected client. None if not logged in.
voice : Optional[:class:`VoiceClient`] voice_clients : iterable of :class:`VoiceClient`
Represents the current voice connection. None if you are not connected Represents a list of voice connections. To connect to voice use
to a voice channel. To connect to voice use :meth:`join_voice_channel`. :meth:`join_voice_channel`. To query the voice connection state use
To query the voice connection state use :meth:`is_voice_connected`. :meth:`is_voice_connected`.
servers : iterable of :class:`Server` servers : iterable of :class:`Server`
The servers that the connected client is a member of. The servers that the connected client is a member of.
private_channels : iterable of :class:`PrivateChannel` private_channels : iterable of :class:`PrivateChannel`
@ -114,7 +114,6 @@ class Client:
def __init__(self, *, loop=None, **options): def __init__(self, *, loop=None, **options):
self.ws = None self.ws = None
self.token = None self.token = None
self.voice = None
self.loop = asyncio.get_event_loop() if loop is None else loop self.loop = asyncio.get_event_loop() if loop is None else loop
self._listeners = [] self._listeners = []
self.cache_auth = options.get('cache_auth', True) self.cache_auth = options.get('cache_auth', True)
@ -227,14 +226,14 @@ class Client:
raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object')
def __getattr__(self, name): def __getattr__(self, name):
if name in ('user', 'servers', 'private_channels', 'messages'): if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'):
return getattr(self.connection, name) return getattr(self.connection, name)
else: else:
msg = "'{}' object has no attribute '{}'" msg = "'{}' object has no attribute '{}'"
raise AttributeError(msg.format(self.__class__, name)) raise AttributeError(msg.format(self.__class__, name))
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in ('user', 'servers', 'private_channels', 'messages'): if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'):
return setattr(self.connection, name, value) return setattr(self.connection, name, value)
else: else:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
@ -418,13 +417,13 @@ class Client:
if self.is_closed: if self.is_closed:
return return
if self.is_voice_connected():
yield from self.voice.disconnect()
self.voice = None
if self.ws is not None and self.ws.open: if self.ws is not None and self.ws.open:
yield from self.ws.close() yield from self.ws.close()
for voice in list(self.voice_clients):
yield from voice.disconnect()
self.connection._remove_voice_client(voice.server.id)
yield from self.session.close() yield from self.session.close()
self._closed.set() self._closed.set()
self._is_ready.clear() self._is_ready.clear()
@ -2415,15 +2414,17 @@ class Client:
:class:`VoiceClient` :class:`VoiceClient`
A voice client that is fully connected to the voice server. A voice client that is fully connected to the voice server.
""" """
if self.is_voice_connected():
raise ClientException('Already connected to a voice channel')
if isinstance(channel, Object): if isinstance(channel, Object):
channel = self.get_channel(channel.id) channel = self.get_channel(channel.id)
if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: if getattr(channel, 'type', ChannelType.text) != ChannelType.voice:
raise InvalidArgument('Channel passed must be a voice channel') raise InvalidArgument('Channel passed must be a voice channel')
server = channel.server
if self.is_voice_connected(server):
raise ClientException('Already connected to a voice channel in this server')
log.info('attempting to join voice channel {0.name}'.format(channel)) log.info('attempting to join voice channel {0.name}'.format(channel))
def session_id_found(data): def session_id_found(data):
@ -2435,14 +2436,10 @@ class Client:
voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True) voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True)
# request joining # request joining
yield from self.ws.voice_state(channel.server.id, channel.id) yield from self.ws.voice_state(server.id, channel.id)
session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop) session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop)
data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop) data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop)
# todo: multivoice
if self.is_voice_connected():
self.voice.channel = self.get_channel(session_id_data.get('channel_id'))
kwargs = { kwargs = {
'user': self.user, 'user': self.user,
'channel': channel, 'channel': channel,
@ -2452,10 +2449,36 @@ class Client:
'main_ws': self.ws 'main_ws': self.ws
} }
self.voice = VoiceClient(**kwargs) voice = VoiceClient(**kwargs)
yield from self.voice.connect() yield from voice.connect()
return self.voice self.connection._add_voice_client(server.id, voice)
return voice
def is_voice_connected(self, server):
"""Indicates if we are currently connected to a voice channel in the
specified server.
Parameters
-----------
server : :class:`Server`
The server to query if we're connected to it.
"""
voice = self.voice_client_in(server)
return voice is not None
def voice_client_in(self, server):
"""Returns the voice client associated with a server.
If no voice client is found then ``None`` is returned.
def is_voice_connected(self): Parameters
"""bool : Indicates if we are currently connected to a voice channel.""" -----------
return self.voice is not None and self.voice.is_connected() server : :class:`Server`
The server to query if we have a voice client for.
Returns
--------
:class:`VoiceClient`
The voice client associated with the server.
"""
return self.connection._get_voice_client(server.id)

61
discord/gateway.py

@ -179,35 +179,21 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# the keep alive # the keep alive
self._keep_alive = None self._keep_alive = None
@classmethod @classmethod
@asyncio.coroutine @asyncio.coroutine
def connect(cls, dispatch, *, token=None, connection=None, loop=None): def from_client(cls, client):
"""Creates a main websocket for Discord used for the client. """Creates a main websocket for Discord from a :class:`Client`.
Parameters
----------
token : str
The token for Discord authentication.
connection
The ConnectionState for the client.
dispatch
The function that dispatches events.
loop
The event loop to use.
Returns This is for internal use only.
-------
DiscordWebSocket
A websocket connected to Discord.
""" """
gateway = yield from get_gateway(client.token, loop=client.loop)
gateway = yield from get_gateway(token, loop=loop) ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
ws = yield from websockets.connect(gateway, loop=loop, klass=cls)
# dynamically add attributes needed # dynamically add attributes needed
ws.token = token ws.token = client.token
ws._connection = connection ws._connection = client.connection
ws._dispatch = dispatch ws._dispatch = client.dispatch
ws.gateway = gateway ws.gateway = gateway
log.info('Created websocket connected to {}'.format(gateway)) log.info('Created websocket connected to {}'.format(gateway))
@ -215,16 +201,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
log.info('sent the identify payload to create the websocket') log.info('sent the identify payload to create the websocket')
return ws return ws
@classmethod
def from_client(cls, client):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
"""
return cls.connect(client.dispatch, token=client.token,
connection=client.connection,
loop=client.loop)
def wait_for(self, event, predicate, result=None): def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate. """Waits for a DISPATCH'd event that meets the predicate.
@ -280,6 +256,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
msg = msg.decode('utf-8') msg = msg.decode('utf-8')
msg = json.loads(msg) msg = json.loads(msg)
state = self._connection
log.debug('WebSocket Event: {}'.format(msg)) log.debug('WebSocket Event: {}'.format(msg))
self._dispatch('socket_response', msg) self._dispatch('socket_response', msg)
@ -288,7 +265,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
data = msg.get('d') data = msg.get('d')
if 's' in msg: if 's' in msg:
self._connection.sequence = msg['s'] state.sequence = msg['s']
if op == self.RECONNECT: if op == self.RECONNECT:
# "reconnect" can only be handled by the Client # "reconnect" can only be handled by the Client
@ -299,8 +276,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
raise ReconnectWebSocket() raise ReconnectWebSocket()
if op == self.INVALIDATE_SESSION: if op == self.INVALIDATE_SESSION:
self._connection.sequence = None state.sequence = None
self._connection.session_id = None state.session_id = None
return return
if op != self.DISPATCH: if op != self.DISPATCH:
@ -311,9 +288,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
is_ready = event == 'READY' is_ready = event == 'READY'
if is_ready: if is_ready:
self._connection.clear() state.clear()
self._connection.sequence = msg['s'] state.sequence = msg['s']
self._connection.session_id = data['session_id'] state.session_id = data['session_id']
if is_ready or event == 'RESUMED': if is_ready or event == 'RESUMED':
interval = data['heartbeat_interval'] / 1000.0 interval = data['heartbeat_interval'] / 1000.0
@ -366,7 +343,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
msg = yield from self.recv() msg = yield from self.recv()
yield from self.received_message(msg) yield from self.received_message(msg)
except websockets.exceptions.ConnectionClosed as e: except websockets.exceptions.ConnectionClosed as e:
if e.code in (4008, 4009) or e.code in range(1001, 1015): if e.code in (4006, 4008, 4009) or e.code in range(1001, 1015):
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
raise ReconnectWebSocket() from e raise ReconnectWebSocket() from e
else: else:
@ -424,6 +401,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from self.send_as_json(payload) yield from self.send_as_json(payload)
# we're leaving a voice channel so remove it from the client list
if channel_id is None:
self._connection._remove_voice_client(guild_id)
@asyncio.coroutine @asyncio.coroutine
def close(self, code=1000, reason=''): def close(self, code=1000, reason=''):
if self._keep_alive: if self._keep_alive:

21
discord/state.py

@ -62,6 +62,7 @@ class ConnectionState:
self.sequence = None self.sequence = None
self.session_id = None self.session_id = None
self._servers = {} self._servers = {}
self._voice_clients = {}
self._private_channels = {} self._private_channels = {}
# extra dict to look up private channels by user id # extra dict to look up private channels by user id
self._private_channels_by_user = {} self._private_channels_by_user = {}
@ -93,6 +94,19 @@ class ConnectionState:
for index in reversed(removed): for index in reversed(removed):
del self._listeners[index] del self._listeners[index]
@property
def voice_clients(self):
return self._voice_clients.values()
def _get_voice_client(self, guild_id):
return self._voice_clients.get(guild_id)
def _add_voice_client(self, guild_id, voice):
self._voice_clients[guild_id] = voice
def _remove_voice_client(self, guild_id):
self._voice_clients.pop(guild_id, None)
@property @property
def servers(self): def servers(self):
return self._servers.values() return self._servers.values()
@ -130,6 +144,7 @@ class ConnectionState:
def _add_server_from_data(self, guild): def _add_server_from_data(self, guild):
server = Server(**guild) server = Server(**guild)
Server.me = property(lambda s: s.get_member(self.user.id)) Server.me = property(lambda s: s.get_member(self.user.id))
Server.voice_client = property(lambda s: self._get_voice_client(s.id))
self._add_server(server) self._add_server(server)
return server return server
@ -489,7 +504,13 @@ class ConnectionState:
def parse_voice_state_update(self, data): def parse_voice_state_update(self, data):
server = self._get_server(data.get('guild_id')) server = self._get_server(data.get('guild_id'))
user_id = data.get('user_id')
if server is not None: if server is not None:
if user_id == self.user.id:
voice = self._get_voice_client(server.id)
if voice is not None:
voice.channel = server.get_channel(data.get('channel_id'))
updated_members = server._update_voice_state(data) updated_members = server._update_voice_state(data)
self.dispatch('voice_state_update', *updated_members) self.dispatch('voice_state_update', *updated_members)

7
discord/voice_client.py

@ -158,6 +158,9 @@ class VoiceClient:
The endpoint we are connecting to. The endpoint we are connecting to.
channel : :class:`Channel` channel : :class:`Channel`
The voice channel connected to. The voice channel connected to.
server : :class:`Server`
The server the voice channel is connected to.
Shorthand for ``channel.server``.
loop loop
The event loop that the voice client is running on. The event loop that the voice client is running on.
""" """
@ -176,6 +179,10 @@ class VoiceClient:
self.encoder = OpusEncoder(48000, 2) self.encoder = OpusEncoder(48000, 2)
log.info('created opus encoder with {0.__dict__}'.format(self.encoder)) log.info('created opus encoder with {0.__dict__}'.format(self.encoder))
@property
def server(self):
return self.channel.server
def checked_add(self, attr, value, limit): def checked_add(self, attr, value, limit):
val = getattr(self, attr) val = getattr(self, attr)
if val + value > limit: if val + value > limit:

Loading…
Cancel
Save