diff --git a/discord/client.py b/discord/client.py index 4e364af98..3d7b8afaa 100644 --- a/discord/client.py +++ b/discord/client.py @@ -90,10 +90,10 @@ class Client: ----------- user : Optional[:class:`User`] Represents the connected client. None if not logged in. - voice : Optional[:class:`VoiceClient`] - Represents the current voice connection. None if you are not connected - to a voice channel. To connect to voice use :meth:`join_voice_channel`. - To query the voice connection state use :meth:`is_voice_connected`. + voice_clients : iterable of :class:`VoiceClient` + Represents a list of voice connections. To connect to voice use + :meth:`join_voice_channel`. To query the voice connection state use + :meth:`is_voice_connected`. servers : iterable of :class:`Server` The servers that the connected client is a member of. private_channels : iterable of :class:`PrivateChannel` @@ -114,7 +114,6 @@ class Client: def __init__(self, *, loop=None, **options): self.ws = None self.token = None - self.voice = None self.loop = asyncio.get_event_loop() if loop is None else loop self._listeners = [] self.cache_auth = options.get('cache_auth', True) @@ -227,14 +226,14 @@ class Client: raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') def __getattr__(self, name): - if name in ('user', 'servers', 'private_channels', 'messages'): + if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'): return getattr(self.connection, name) else: msg = "'{}' object has no attribute '{}'" raise AttributeError(msg.format(self.__class__, name)) def __setattr__(self, name, value): - if name in ('user', 'servers', 'private_channels', 'messages'): + if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'): return setattr(self.connection, name, value) else: object.__setattr__(self, name, value) @@ -418,13 +417,13 @@ class Client: if self.is_closed: return - if self.is_voice_connected(): - yield from self.voice.disconnect() - self.voice = None - if self.ws is not None and self.ws.open: yield from self.ws.close() + for voice in list(self.voice_clients): + yield from voice.disconnect() + self.connection._remove_voice_client(voice.server.id) + yield from self.session.close() self._closed.set() self._is_ready.clear() @@ -2415,15 +2414,17 @@ class Client: :class:`VoiceClient` A voice client that is fully connected to the voice server. """ - if self.is_voice_connected(): - raise ClientException('Already connected to a voice channel') - if isinstance(channel, Object): channel = self.get_channel(channel.id) if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: raise InvalidArgument('Channel passed must be a voice channel') + server = channel.server + + if self.is_voice_connected(server): + raise ClientException('Already connected to a voice channel in this server') + log.info('attempting to join voice channel {0.name}'.format(channel)) def session_id_found(data): @@ -2435,14 +2436,10 @@ class Client: voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True) # request joining - yield from self.ws.voice_state(channel.server.id, channel.id) + yield from self.ws.voice_state(server.id, channel.id) session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop) data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop) - # todo: multivoice - if self.is_voice_connected(): - self.voice.channel = self.get_channel(session_id_data.get('channel_id')) - kwargs = { 'user': self.user, 'channel': channel, @@ -2452,10 +2449,36 @@ class Client: 'main_ws': self.ws } - self.voice = VoiceClient(**kwargs) - yield from self.voice.connect() - return self.voice + voice = VoiceClient(**kwargs) + yield from voice.connect() + self.connection._add_voice_client(server.id, voice) + return voice + + def is_voice_connected(self, server): + """Indicates if we are currently connected to a voice channel in the + specified server. + + Parameters + ----------- + server : :class:`Server` + The server to query if we're connected to it. + """ + voice = self.voice_client_in(server) + return voice is not None + + def voice_client_in(self, server): + """Returns the voice client associated with a server. + + If no voice client is found then ``None`` is returned. - def is_voice_connected(self): - """bool : Indicates if we are currently connected to a voice channel.""" - return self.voice is not None and self.voice.is_connected() + Parameters + ----------- + server : :class:`Server` + The server to query if we have a voice client for. + + Returns + -------- + :class:`VoiceClient` + The voice client associated with the server. + """ + return self.connection._get_voice_client(server.id) diff --git a/discord/gateway.py b/discord/gateway.py index ccfc98df9..f4c7d2277 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -179,35 +179,21 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # the keep alive self._keep_alive = None + @classmethod @asyncio.coroutine - def connect(cls, dispatch, *, token=None, connection=None, loop=None): - """Creates a main websocket for Discord used for the client. - - Parameters - ---------- - token : str - The token for Discord authentication. - connection - The ConnectionState for the client. - dispatch - The function that dispatches events. - loop - The event loop to use. + def from_client(cls, client): + """Creates a main websocket for Discord from a :class:`Client`. - Returns - ------- - DiscordWebSocket - A websocket connected to Discord. + This is for internal use only. """ - - gateway = yield from get_gateway(token, loop=loop) - ws = yield from websockets.connect(gateway, loop=loop, klass=cls) + gateway = yield from get_gateway(client.token, loop=client.loop) + ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) # dynamically add attributes needed - ws.token = token - ws._connection = connection - ws._dispatch = dispatch + ws.token = client.token + ws._connection = client.connection + ws._dispatch = client.dispatch ws.gateway = gateway log.info('Created websocket connected to {}'.format(gateway)) @@ -215,16 +201,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): log.info('sent the identify payload to create the websocket') return ws - @classmethod - def from_client(cls, client): - """Creates a main websocket for Discord from a :class:`Client`. - - This is for internal use only. - """ - return cls.connect(client.dispatch, token=client.token, - connection=client.connection, - loop=client.loop) - def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. @@ -280,6 +256,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): msg = msg.decode('utf-8') msg = json.loads(msg) + state = self._connection log.debug('WebSocket Event: {}'.format(msg)) self._dispatch('socket_response', msg) @@ -288,7 +265,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): data = msg.get('d') if 's' in msg: - self._connection.sequence = msg['s'] + state.sequence = msg['s'] if op == self.RECONNECT: # "reconnect" can only be handled by the Client @@ -299,8 +276,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): raise ReconnectWebSocket() if op == self.INVALIDATE_SESSION: - self._connection.sequence = None - self._connection.session_id = None + state.sequence = None + state.session_id = None return if op != self.DISPATCH: @@ -311,9 +288,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): is_ready = event == 'READY' if is_ready: - self._connection.clear() - self._connection.sequence = msg['s'] - self._connection.session_id = data['session_id'] + state.clear() + state.sequence = msg['s'] + state.session_id = data['session_id'] if is_ready or event == 'RESUMED': interval = data['heartbeat_interval'] / 1000.0 @@ -366,7 +343,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): msg = yield from self.recv() yield from self.received_message(msg) except websockets.exceptions.ConnectionClosed as e: - if e.code in (4008, 4009) or e.code in range(1001, 1015): + if e.code in (4006, 4008, 4009) or e.code in range(1001, 1015): log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) raise ReconnectWebSocket() from e else: @@ -424,6 +401,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): yield from self.send_as_json(payload) + # we're leaving a voice channel so remove it from the client list + if channel_id is None: + self._connection._remove_voice_client(guild_id) + @asyncio.coroutine def close(self, code=1000, reason=''): if self._keep_alive: diff --git a/discord/state.py b/discord/state.py index 64be587fa..8209dbfd5 100644 --- a/discord/state.py +++ b/discord/state.py @@ -62,6 +62,7 @@ class ConnectionState: self.sequence = None self.session_id = None self._servers = {} + self._voice_clients = {} self._private_channels = {} # extra dict to look up private channels by user id self._private_channels_by_user = {} @@ -93,6 +94,19 @@ class ConnectionState: for index in reversed(removed): del self._listeners[index] + @property + def voice_clients(self): + return self._voice_clients.values() + + def _get_voice_client(self, guild_id): + return self._voice_clients.get(guild_id) + + def _add_voice_client(self, guild_id, voice): + self._voice_clients[guild_id] = voice + + def _remove_voice_client(self, guild_id): + self._voice_clients.pop(guild_id, None) + @property def servers(self): return self._servers.values() @@ -130,6 +144,7 @@ class ConnectionState: def _add_server_from_data(self, guild): server = Server(**guild) Server.me = property(lambda s: s.get_member(self.user.id)) + Server.voice_client = property(lambda s: self._get_voice_client(s.id)) self._add_server(server) return server @@ -489,7 +504,13 @@ class ConnectionState: def parse_voice_state_update(self, data): server = self._get_server(data.get('guild_id')) + user_id = data.get('user_id') if server is not None: + if user_id == self.user.id: + voice = self._get_voice_client(server.id) + if voice is not None: + voice.channel = server.get_channel(data.get('channel_id')) + updated_members = server._update_voice_state(data) self.dispatch('voice_state_update', *updated_members) diff --git a/discord/voice_client.py b/discord/voice_client.py index 26b353ccf..4079138fc 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -158,6 +158,9 @@ class VoiceClient: The endpoint we are connecting to. channel : :class:`Channel` The voice channel connected to. + server : :class:`Server` + The server the voice channel is connected to. + Shorthand for ``channel.server``. loop The event loop that the voice client is running on. """ @@ -176,6 +179,10 @@ class VoiceClient: self.encoder = OpusEncoder(48000, 2) log.info('created opus encoder with {0.__dict__}'.format(self.encoder)) + @property + def server(self): + return self.channel.server + def checked_add(self, attr, value, limit): val = getattr(self, attr) if val + value > limit: