diff --git a/discord/gateway.py b/discord/gateway.py index 4e1f78c68..ccb58f0a6 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -44,6 +44,11 @@ from .activity import BaseActivity from .enums import SpeakingState from .errors import ConnectionClosed +try: + import davey # type: ignore +except ImportError: + pass + _log = logging.getLogger(__name__) __all__ = ( @@ -812,18 +817,30 @@ class DiscordVoiceWebSocket: _max_heartbeat_timeout: float # fmt: off - IDENTIFY = 0 - SELECT_PROTOCOL = 1 - READY = 2 - HEARTBEAT = 3 - SESSION_DESCRIPTION = 4 - SPEAKING = 5 - HEARTBEAT_ACK = 6 - RESUME = 7 - HELLO = 8 - RESUMED = 9 - CLIENT_CONNECT = 12 - CLIENT_DISCONNECT = 13 + IDENTIFY = 0 + SELECT_PROTOCOL = 1 + READY = 2 + HEARTBEAT = 3 + SESSION_DESCRIPTION = 4 + SPEAKING = 5 + HEARTBEAT_ACK = 6 + RESUME = 7 + HELLO = 8 + RESUMED = 9 + CLIENTS_CONNECT = 11 + CLIENT_CONNECT = 12 + CLIENT_DISCONNECT = 13 + DAVE_PREPARE_TRANSITION = 21 + DAVE_EXECUTE_TRANSITION = 22 + DAVE_TRANSITION_READY = 23 + DAVE_PREPARE_EPOCH = 24 + MLS_EXTERNAL_SENDER = 25 + MLS_KEY_PACKAGE = 26 + MLS_PROPOSALS = 27 + MLS_COMMIT_WELCOME = 28 + MLS_ANNOUNCE_COMMIT_TRANSITION = 29 + MLS_WELCOME = 30 + MLS_INVALID_COMMIT_WELCOME = 31 # fmt: on def __init__( @@ -850,6 +867,10 @@ class DiscordVoiceWebSocket: _log.debug('Sending voice websocket frame: %s.', data) await self.ws.send_str(utils._to_json(data)) + async def send_binary(self, opcode: int, data: bytes) -> None: + _log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, len(data)) + await self.ws.send_bytes(bytes([opcode]) + data) + send_heartbeat = send_as_json async def resume(self) -> None: @@ -874,6 +895,7 @@ class DiscordVoiceWebSocket: 'user_id': str(state.user.id), 'session_id': state.session_id, 'token': state.token, + 'max_dave_protocol_version': state.max_dave_protocol_version, }, } await self.send_as_json(payload) @@ -943,6 +965,16 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) + async def send_transition_ready(self, transition_id: int): + payload = { + 'op': DiscordVoiceWebSocket.DAVE_TRANSITION_READY, + 'd': { + 'transition_id': transition_id, + }, + } + + await self.send_as_json(payload) + async def received_message(self, msg: Dict[str, Any]) -> None: _log.debug('Voice websocket frame received: %s', msg) op = msg['op'] @@ -959,13 +991,84 @@ class DiscordVoiceWebSocket: elif op == self.SESSION_DESCRIPTION: self._connection.mode = data['mode'] await self.load_secret_key(data) + self._connection.dave_protocol_version = data['dave_protocol_version'] + if data['dave_protocol_version'] > 0: + await self._connection.reinit_dave_session() elif op == self.HELLO: interval = data['heartbeat_interval'] / 1000.0 self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive.start() + elif self._connection.dave_session: + state = self._connection + if op == self.DAVE_PREPARE_TRANSITION: + _log.debug( + 'Preparing for DAVE transition id %d for protocol version %d', + data['transition_id'], + data['protocol_version'], + ) + state.dave_pending_transitions[data['transition_id']] = data['protocol_version'] + if data['transition_id'] == 0: + await state._execute_transition(data['transition_id']) + else: + if data['protocol_version'] == 0 and state.dave_session: + state.dave_session.set_passthrough_mode(True, 120) + + await self.send_transition_ready(data['transition_id']) + elif op == self.DAVE_EXECUTE_TRANSITION: + _log.debug('Executing DAVE transition id %d', data['transition_id']) + await state._execute_transition(data['transition_id']) + elif op == self.DAVE_PREPARE_EPOCH: + _log.debug('Preparing for DAVE epoch %d', data['epoch']) + # When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version. + if data['epoch'] == 1: + state.dave_protocol_version = data['protocol_version'] + await state.reinit_dave_session() await self._hook(self, msg) + async def recieved_binary_message(self, msg: bytes) -> None: + self.seq_ack = struct.unpack_from('>H', msg, 0)[0] + op = msg[2] + _log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op) + state = self._connection + + if state.dave_session: + if op == self.MLS_EXTERNAL_SENDER: + state.dave_session.set_external_sender(msg[3:]) + _log.debug('Set MLS external sender') + elif op == self.MLS_PROPOSALS: + optype = msg[3] + result = state.dave_session.process_proposals( + davey.ProposalsOperationType.append if optype == 0 else davey.ProposalsOperationType.revoke, msg[4:] + ) + if isinstance(result, davey.CommitWelcome): + await self.send_binary( + DiscordVoiceWebSocket.MLS_COMMIT_WELCOME, + result.commit + result.welcome if result.welcome else result.commit, + ) + _log.debug('MLS proposals processed') + elif op == self.MLS_ANNOUNCE_COMMIT_TRANSITION: + transition_id = struct.unpack_from('>H', msg, 3)[0] + try: + state.dave_session.process_commit(msg[5:]) + if transition_id != 0: + state.dave_pending_transitions[transition_id] = state.dave_protocol_version + await self.send_transition_ready(transition_id) + _log.debug('MLS commit processed for transition id %d', transition_id) + except Exception: + await state._recover_from_invalid_commit(transition_id) + elif op == self.MLS_WELCOME: + transition_id = struct.unpack_from('>H', msg, 3)[0] + try: + state.dave_session.process_welcome(msg[5:]) + if transition_id != 0: + state.dave_pending_transitions[transition_id] = state.dave_protocol_version + await self.send_transition_ready(transition_id) + _log.debug('MLS welcome processed for transition id %d', transition_id) + except Exception: + await state._recover_from_invalid_commit(transition_id) + pass + async def initial_connection(self, data: Dict[str, Any]) -> None: state = self._connection state.ssrc = data['ssrc'] @@ -1045,6 +1148,8 @@ class DiscordVoiceWebSocket: msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) + elif msg.type is aiohttp.WSMsgType.BINARY: + await self.recieved_binary_message(msg.data) elif msg.type is aiohttp.WSMsgType.ERROR: _log.debug('Received voice %s', msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data diff --git a/discord/voice_client.py b/discord/voice_client.py index b0f3e951b..aeb549b1b 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -284,6 +284,15 @@ class VoiceClient(VoiceProtocol): def timeout(self) -> float: return self._connection.timeout + @property + def voice_privacy_code(self) -> Optional[str]: + """:class:`str`: Get the voice privacy code of this E2EE session's group. + + A new privacy code is created and cached each time a new transition is executed. + This can be None if there is no active DAVE session happening. + """ + return self._connection.dave_session.voice_privacy_code if self._connection.dave_session else None + def checked_add(self, attr: str, value: int, limit: int) -> None: val = getattr(self, attr) if val + value > limit: @@ -368,7 +377,12 @@ class VoiceClient(VoiceProtocol): # audio related - def _get_voice_packet(self, data): + def _get_voice_packet(self, data: bytes): + packet = ( + self._connection.dave_session.encrypt_opus(data) + if self._connection.dave_session and self._connection.can_encrypt + else data + ) header = bytearray(12) # Formulate rtp header @@ -379,7 +393,7 @@ class VoiceClient(VoiceProtocol): struct.pack_into('>I', header, 8, self.ssrc) encrypt_packet = getattr(self, '_encrypt_' + self.mode) - return encrypt_packet(header, data) + return encrypt_packet(header, packet) def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes: # Esentially the same as _lite diff --git a/discord/voice_state.py b/discord/voice_state.py index 5e78c7851..04cc11b61 100644 --- a/discord/voice_state.py +++ b/discord/voice_state.py @@ -69,6 +69,14 @@ if TYPE_CHECKING: WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]] SocketReaderCallback = Callable[[bytes], Any] +has_dave: bool + +try: + import davey # type: ignore + + has_dave = True +except ImportError: + has_dave = False __all__ = ('VoiceConnectionState',) @@ -208,6 +216,10 @@ class VoiceConnectionState: self.mode: SupportedModes = MISSING self.socket: socket.socket = MISSING self.ws: DiscordVoiceWebSocket = MISSING + self.dave_session: Optional[davey.DaveSession] = None + self.dave_protocol_version: int = 0 + self.dave_pending_transitions: Dict[int, int] = {} + self.dave_downgraded: bool = False self._state: ConnectionFlowState = ConnectionFlowState.disconnected self._expecting_disconnect: bool = False @@ -252,6 +264,64 @@ class VoiceConnectionState: def self_voice_state(self) -> Optional[VoiceState]: return self.guild.me.voice + @property + def max_dave_protocol_version(self) -> int: + return davey.DAVE_PROTOCOL_VERSION if has_dave else 0 + + @property + def can_encrypt(self) -> bool: + return self.dave_protocol_version != 0 and self.dave_session != None and self.dave_session.ready + + async def reinit_dave_session(self) -> None: + if self.dave_protocol_version > 0: + if not has_dave: + raise RuntimeError('davey library needed in order to use E2EE voice') + if self.dave_session is not None: + self.dave_session.reinit(self.dave_protocol_version, self.user.id, self.voice_client.channel.id) + else: + self.dave_session = davey.DaveSession(self.dave_protocol_version, self.user.id, self.voice_client.channel.id) + + if self.dave_session is not None: + await self.voice_client.ws.send_binary( + DiscordVoiceWebSocket.MLS_KEY_PACKAGE, self.dave_session.get_serialized_key_package() + ) + elif self.dave_session: + self.dave_session.reset() + self.dave_session.set_passthrough_mode(True, 10) + pass + + async def _recover_from_invalid_commit(self, transition_id: int) -> None: + payload = { + 'op': DiscordVoiceWebSocket.MLS_INVALID_COMMIT_WELCOME, + 'd': { + 'transition_id': transition_id, + }, + } + + await self.voice_client.ws.send_as_json(payload) + await self.reinit_dave_session() + + async def _execute_transition(self, transition_id: int) -> None: + _log.debug('Executing transition id %d', transition_id) + if transition_id not in self.dave_pending_transitions: + _log.warning("Received execute transition, but we don't have a pending transition for id %d", transition_id) + return + + old_version = self.dave_protocol_version + self.dave_protocol_version = self.dave_pending_transitions.pop(transition_id) + + if old_version != self.dave_protocol_version and self.dave_protocol_version == 0: + self.dave_downgraded = True + _log.debug('DAVE Session downgraded') + elif transition_id > 0 and self.dave_downgraded: + self.dave_downgraded = False + if self.dave_session: + self.dave_session.set_passthrough_mode(True, 10) + _log.debug('DAVE Session upgraded') + + # In the future, the session should be signaled too, but for now theres just v1 + _log.debug('Transition id %d executed', transition_id) + async def voice_state_update(self, data: GuildVoiceStatePayload) -> None: channel_id = data['channel_id'] diff --git a/pyproject.toml b/pyproject.toml index d32ed9a29..1c45d5f7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,10 @@ Documentation = "https://discordpy.readthedocs.io/en/latest/" dependencies = { file = "requirements.txt" } [project.optional-dependencies] -voice = ["PyNaCl>=1.5.0,<1.6"] +voice = [ + "PyNaCl>=1.5.0,<1.6", + "davey==0.1.0" +] docs = [ "sphinx==4.4.0", "sphinxcontrib_trio==1.1.2",