Browse Source

Merge 5f257fd016 into c58b973c7e

pull/10300/merge
Snazzah 2 days ago
committed by GitHub
parent
commit
d89717c2b5
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 129
      discord/gateway.py
  2. 18
      discord/voice_client.py
  3. 70
      discord/voice_state.py
  4. 5
      pyproject.toml

129
discord/gateway.py

@ -44,6 +44,11 @@ from .activity import BaseActivity
from .enums import SpeakingState
from .errors import ConnectionClosed
try:
import davey # type: ignore
except ImportError:
pass
_log = logging.getLogger(__name__)
__all__ = (
@ -812,18 +817,30 @@ class DiscordVoiceWebSocket:
_max_heartbeat_timeout: float
# fmt: off
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENTS_CONNECT = 11
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
DAVE_PREPARE_TRANSITION = 21
DAVE_EXECUTE_TRANSITION = 22
DAVE_TRANSITION_READY = 23
DAVE_PREPARE_EPOCH = 24
MLS_EXTERNAL_SENDER = 25
MLS_KEY_PACKAGE = 26
MLS_PROPOSALS = 27
MLS_COMMIT_WELCOME = 28
MLS_ANNOUNCE_COMMIT_TRANSITION = 29
MLS_WELCOME = 30
MLS_INVALID_COMMIT_WELCOME = 31
# fmt: on
def __init__(
@ -850,6 +867,10 @@ class DiscordVoiceWebSocket:
_log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data))
async def send_binary(self, opcode: int, data: bytes) -> None:
_log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, len(data))
await self.ws.send_bytes(bytes([opcode]) + data)
send_heartbeat = send_as_json
async def resume(self) -> None:
@ -874,6 +895,7 @@ class DiscordVoiceWebSocket:
'user_id': str(state.user.id),
'session_id': state.session_id,
'token': state.token,
'max_dave_protocol_version': state.max_dave_protocol_version,
},
}
await self.send_as_json(payload)
@ -943,6 +965,16 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
async def send_transition_ready(self, transition_id: int):
payload = {
'op': DiscordVoiceWebSocket.DAVE_TRANSITION_READY,
'd': {
'transition_id': transition_id,
},
}
await self.send_as_json(payload)
async def received_message(self, msg: Dict[str, Any]) -> None:
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
@ -959,13 +991,84 @@ class DiscordVoiceWebSocket:
elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode']
await self.load_secret_key(data)
self._connection.dave_protocol_version = data['dave_protocol_version']
if data['dave_protocol_version'] > 0:
await self._connection.reinit_dave_session()
elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start()
elif self._connection.dave_session:
state = self._connection
if op == self.DAVE_PREPARE_TRANSITION:
_log.debug(
'Preparing for DAVE transition id %d for protocol version %d',
data['transition_id'],
data['protocol_version'],
)
state.dave_pending_transitions[data['transition_id']] = data['protocol_version']
if data['transition_id'] == 0:
await state._execute_transition(data['transition_id'])
else:
if data['protocol_version'] == 0 and state.dave_session:
state.dave_session.set_passthrough_mode(True, 120)
await self.send_transition_ready(data['transition_id'])
elif op == self.DAVE_EXECUTE_TRANSITION:
_log.debug('Executing DAVE transition id %d', data['transition_id'])
await state._execute_transition(data['transition_id'])
elif op == self.DAVE_PREPARE_EPOCH:
_log.debug('Preparing for DAVE epoch %d', data['epoch'])
# When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version.
if data['epoch'] == 1:
state.dave_protocol_version = data['protocol_version']
await state.reinit_dave_session()
await self._hook(self, msg)
async def recieved_binary_message(self, msg: bytes) -> None:
self.seq_ack = struct.unpack_from('>H', msg, 0)[0]
op = msg[2]
_log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op)
state = self._connection
if state.dave_session:
if op == self.MLS_EXTERNAL_SENDER:
state.dave_session.set_external_sender(msg[3:])
_log.debug('Set MLS external sender')
elif op == self.MLS_PROPOSALS:
optype = msg[3]
result = state.dave_session.process_proposals(
davey.ProposalsOperationType.append if optype == 0 else davey.ProposalsOperationType.revoke, msg[4:]
)
if isinstance(result, davey.CommitWelcome):
await self.send_binary(
DiscordVoiceWebSocket.MLS_COMMIT_WELCOME,
result.commit + result.welcome if result.welcome else result.commit,
)
_log.debug('MLS proposals processed')
elif op == self.MLS_ANNOUNCE_COMMIT_TRANSITION:
transition_id = struct.unpack_from('>H', msg, 3)[0]
try:
state.dave_session.process_commit(msg[5:])
if transition_id != 0:
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
await self.send_transition_ready(transition_id)
_log.debug('MLS commit processed for transition id %d', transition_id)
except Exception:
await state._recover_from_invalid_commit(transition_id)
elif op == self.MLS_WELCOME:
transition_id = struct.unpack_from('>H', msg, 3)[0]
try:
state.dave_session.process_welcome(msg[5:])
if transition_id != 0:
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
await self.send_transition_ready(transition_id)
_log.debug('MLS welcome processed for transition id %d', transition_id)
except Exception:
await state._recover_from_invalid_commit(transition_id)
pass
async def initial_connection(self, data: Dict[str, Any]) -> None:
state = self._connection
state.ssrc = data['ssrc']
@ -1045,6 +1148,8 @@ class DiscordVoiceWebSocket:
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.recieved_binary_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data

18
discord/voice_client.py

@ -284,6 +284,15 @@ class VoiceClient(VoiceProtocol):
def timeout(self) -> float:
return self._connection.timeout
@property
def voice_privacy_code(self) -> Optional[str]:
""":class:`str`: Get the voice privacy code of this E2EE session's group.
A new privacy code is created and cached each time a new transition is executed.
This can be None if there is no active DAVE session happening.
"""
return self._connection.dave_session.voice_privacy_code if self._connection.dave_session else None
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
@ -368,7 +377,12 @@ class VoiceClient(VoiceProtocol):
# audio related
def _get_voice_packet(self, data):
def _get_voice_packet(self, data: bytes):
packet = (
self._connection.dave_session.encrypt_opus(data)
if self._connection.dave_session and self._connection.can_encrypt
else data
)
header = bytearray(12)
# Formulate rtp header
@ -379,7 +393,7 @@ class VoiceClient(VoiceProtocol):
struct.pack_into('>I', header, 8, self.ssrc)
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
return encrypt_packet(header, data)
return encrypt_packet(header, packet)
def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes:
# Esentially the same as _lite

70
discord/voice_state.py

@ -69,6 +69,14 @@ if TYPE_CHECKING:
WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]]
SocketReaderCallback = Callable[[bytes], Any]
has_dave: bool
try:
import davey # type: ignore
has_dave = True
except ImportError:
has_dave = False
__all__ = ('VoiceConnectionState',)
@ -208,6 +216,10 @@ class VoiceConnectionState:
self.mode: SupportedModes = MISSING
self.socket: socket.socket = MISSING
self.ws: DiscordVoiceWebSocket = MISSING
self.dave_session: Optional[davey.DaveSession] = None
self.dave_protocol_version: int = 0
self.dave_pending_transitions: Dict[int, int] = {}
self.dave_downgraded: bool = False
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
self._expecting_disconnect: bool = False
@ -252,6 +264,64 @@ class VoiceConnectionState:
def self_voice_state(self) -> Optional[VoiceState]:
return self.guild.me.voice
@property
def max_dave_protocol_version(self) -> int:
return davey.DAVE_PROTOCOL_VERSION if has_dave else 0
@property
def can_encrypt(self) -> bool:
return self.dave_protocol_version != 0 and self.dave_session != None and self.dave_session.ready
async def reinit_dave_session(self) -> None:
if self.dave_protocol_version > 0:
if not has_dave:
raise RuntimeError('davey library needed in order to use E2EE voice')
if self.dave_session is not None:
self.dave_session.reinit(self.dave_protocol_version, self.user.id, self.voice_client.channel.id)
else:
self.dave_session = davey.DaveSession(self.dave_protocol_version, self.user.id, self.voice_client.channel.id)
if self.dave_session is not None:
await self.voice_client.ws.send_binary(
DiscordVoiceWebSocket.MLS_KEY_PACKAGE, self.dave_session.get_serialized_key_package()
)
elif self.dave_session:
self.dave_session.reset()
self.dave_session.set_passthrough_mode(True, 10)
pass
async def _recover_from_invalid_commit(self, transition_id: int) -> None:
payload = {
'op': DiscordVoiceWebSocket.MLS_INVALID_COMMIT_WELCOME,
'd': {
'transition_id': transition_id,
},
}
await self.voice_client.ws.send_as_json(payload)
await self.reinit_dave_session()
async def _execute_transition(self, transition_id: int) -> None:
_log.debug('Executing transition id %d', transition_id)
if transition_id not in self.dave_pending_transitions:
_log.warning("Received execute transition, but we don't have a pending transition for id %d", transition_id)
return
old_version = self.dave_protocol_version
self.dave_protocol_version = self.dave_pending_transitions.pop(transition_id)
if old_version != self.dave_protocol_version and self.dave_protocol_version == 0:
self.dave_downgraded = True
_log.debug('DAVE Session downgraded')
elif transition_id > 0 and self.dave_downgraded:
self.dave_downgraded = False
if self.dave_session:
self.dave_session.set_passthrough_mode(True, 10)
_log.debug('DAVE Session upgraded')
# In the future, the session should be signaled too, but for now theres just v1
_log.debug('Transition id %d executed', transition_id)
async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
channel_id = data['channel_id']

5
pyproject.toml

@ -36,7 +36,10 @@ Documentation = "https://discordpy.readthedocs.io/en/latest/"
dependencies = { file = "requirements.txt" }
[project.optional-dependencies]
voice = ["PyNaCl>=1.5.0,<1.6"]
voice = [
"PyNaCl>=1.5.0,<1.6",
"davey==0.1.0"
]
docs = [
"sphinx==4.4.0",
"sphinxcontrib_trio==1.1.2",

Loading…
Cancel
Save