|
|
@ -21,9 +21,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|
|
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
|
from collections import namedtuple, deque |
|
|
|
from collections import deque |
|
|
|
import concurrent.futures |
|
|
|
import logging |
|
|
|
import struct |
|
|
@ -33,6 +34,8 @@ import threading |
|
|
|
import traceback |
|
|
|
import zlib |
|
|
|
|
|
|
|
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Type |
|
|
|
|
|
|
|
import aiohttp |
|
|
|
|
|
|
|
from . import utils |
|
|
@ -50,6 +53,11 @@ __all__ = ( |
|
|
|
'ReconnectWebSocket', |
|
|
|
) |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from .client import Client |
|
|
|
from .state import ConnectionState |
|
|
|
from .voice_client import VoiceClient |
|
|
|
|
|
|
|
|
|
|
|
class ReconnectWebSocket(Exception): |
|
|
|
"""Signals to safely reconnect the websocket.""" |
|
|
@ -66,26 +74,30 @@ class WebSocketClosure(Exception): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
EventListener = namedtuple('EventListener', 'predicate event result future') |
|
|
|
class EventListener(NamedTuple): |
|
|
|
predicate: Callable[[Dict[str, Any]], bool] |
|
|
|
event: str |
|
|
|
result: Optional[Callable[[Dict[str, Any]], Any]] |
|
|
|
future: asyncio.Future[Any] |
|
|
|
|
|
|
|
|
|
|
|
class GatewayRatelimiter: |
|
|
|
def __init__(self, count=110, per=60.0): |
|
|
|
def __init__(self, count: int = 110, per: float = 60.0) -> None: |
|
|
|
# The default is 110 to give room for at least 10 heartbeats per minute |
|
|
|
self.max = count |
|
|
|
self.remaining = count |
|
|
|
self.window = 0.0 |
|
|
|
self.per = per |
|
|
|
self.lock = asyncio.Lock() |
|
|
|
self.shard_id = None |
|
|
|
|
|
|
|
def is_ratelimited(self): |
|
|
|
self.max: int = count |
|
|
|
self.remaining: int = count |
|
|
|
self.window: float = 0.0 |
|
|
|
self.per: float = per |
|
|
|
self.lock: asyncio.Lock = asyncio.Lock() |
|
|
|
self.shard_id: Optional[int] = None |
|
|
|
|
|
|
|
def is_ratelimited(self) -> bool: |
|
|
|
current = time.time() |
|
|
|
if current > self.window + self.per: |
|
|
|
return False |
|
|
|
return self.remaining == 0 |
|
|
|
|
|
|
|
def get_delay(self): |
|
|
|
def get_delay(self) -> float: |
|
|
|
current = time.time() |
|
|
|
|
|
|
|
if current > self.window + self.per: |
|
|
@ -103,7 +115,7 @@ class GatewayRatelimiter: |
|
|
|
|
|
|
|
return 0.0 |
|
|
|
|
|
|
|
async def block(self): |
|
|
|
async def block(self) -> None: |
|
|
|
async with self.lock: |
|
|
|
delta = self.get_delay() |
|
|
|
if delta: |
|
|
@ -112,27 +124,31 @@ class GatewayRatelimiter: |
|
|
|
|
|
|
|
|
|
|
|
class KeepAliveHandler(threading.Thread): |
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
ws = kwargs.pop('ws', None) |
|
|
|
interval = kwargs.pop('interval', None) |
|
|
|
shard_id = kwargs.pop('shard_id', None) |
|
|
|
threading.Thread.__init__(self, *args, **kwargs) |
|
|
|
self.ws = ws |
|
|
|
self._main_thread_id = ws.thread_id |
|
|
|
self.interval = interval |
|
|
|
self.daemon = True |
|
|
|
self.shard_id = shard_id |
|
|
|
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' |
|
|
|
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' |
|
|
|
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' |
|
|
|
self._stop_ev = threading.Event() |
|
|
|
self._last_ack = time.perf_counter() |
|
|
|
self._last_send = time.perf_counter() |
|
|
|
self._last_recv = time.perf_counter() |
|
|
|
self.latency = float('inf') |
|
|
|
self.heartbeat_timeout = ws._max_heartbeat_timeout |
|
|
|
|
|
|
|
def run(self): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
*args: Any, |
|
|
|
ws: DiscordWebSocket, |
|
|
|
interval: Optional[float] = None, |
|
|
|
shard_id: Optional[int] = None, |
|
|
|
**kwargs: Any, |
|
|
|
) -> None: |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.ws: DiscordWebSocket = ws |
|
|
|
self._main_thread_id: int = ws.thread_id |
|
|
|
self.interval: Optional[float] = interval |
|
|
|
self.daemon: bool = True |
|
|
|
self.shard_id: Optional[int] = shard_id |
|
|
|
self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.' |
|
|
|
self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.' |
|
|
|
self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' |
|
|
|
self._stop_ev: threading.Event = threading.Event() |
|
|
|
self._last_ack: float = time.perf_counter() |
|
|
|
self._last_send: float = time.perf_counter() |
|
|
|
self._last_recv: float = time.perf_counter() |
|
|
|
self.latency: float = float('inf') |
|
|
|
self.heartbeat_timeout: float = ws._max_heartbeat_timeout |
|
|
|
|
|
|
|
def run(self) -> None: |
|
|
|
while not self._stop_ev.wait(self.interval): |
|
|
|
if self._last_recv + self.heartbeat_timeout < time.perf_counter(): |
|
|
|
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) |
|
|
@ -174,19 +190,19 @@ class KeepAliveHandler(threading.Thread): |
|
|
|
else: |
|
|
|
self._last_send = time.perf_counter() |
|
|
|
|
|
|
|
def get_payload(self): |
|
|
|
def get_payload(self) -> Dict[str, Any]: |
|
|
|
return { |
|
|
|
'op': self.ws.HEARTBEAT, |
|
|
|
'd': self.ws.sequence, |
|
|
|
} |
|
|
|
|
|
|
|
def stop(self): |
|
|
|
def stop(self) -> None: |
|
|
|
self._stop_ev.set() |
|
|
|
|
|
|
|
def tick(self): |
|
|
|
def tick(self) -> None: |
|
|
|
self._last_recv = time.perf_counter() |
|
|
|
|
|
|
|
def ack(self): |
|
|
|
def ack(self) -> None: |
|
|
|
ack_time = time.perf_counter() |
|
|
|
self._last_ack = ack_time |
|
|
|
self.latency = ack_time - self._last_send |
|
|
@ -195,20 +211,20 @@ class KeepAliveHandler(threading.Thread): |
|
|
|
|
|
|
|
|
|
|
|
class VoiceKeepAliveHandler(KeepAliveHandler): |
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None: |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.recent_ack_latencies = deque(maxlen=20) |
|
|
|
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' |
|
|
|
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' |
|
|
|
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' |
|
|
|
self.recent_ack_latencies: Deque[float] = deque(maxlen=20) |
|
|
|
self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.' |
|
|
|
self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds' |
|
|
|
self.behind_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind' |
|
|
|
|
|
|
|
def get_payload(self): |
|
|
|
def get_payload(self) -> Dict[str, Any]: |
|
|
|
return { |
|
|
|
'op': self.ws.HEARTBEAT, |
|
|
|
'd': int(time.time() * 1000), |
|
|
|
} |
|
|
|
|
|
|
|
def ack(self): |
|
|
|
def ack(self) -> None: |
|
|
|
ack_time = time.perf_counter() |
|
|
|
self._last_ack = ack_time |
|
|
|
self._last_recv = ack_time |
|
|
@ -221,6 +237,9 @@ class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): |
|
|
|
return await super().close(code=code, message=message) |
|
|
|
|
|
|
|
|
|
|
|
DWS = TypeVar('DWS', bound='DiscordWebSocket') |
|
|
|
|
|
|
|
|
|
|
|
class DiscordWebSocket: |
|
|
|
"""Implements a WebSocket for Discord's gateway v6. |
|
|
|
|
|
|
@ -261,6 +280,17 @@ class DiscordWebSocket: |
|
|
|
The authentication token for discord. |
|
|
|
""" |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
token: Optional[str] |
|
|
|
_connection: ConnectionState |
|
|
|
_discord_parsers: Dict[str, Callable[..., Any]] |
|
|
|
call_hooks: Callable[..., Any] |
|
|
|
_initial_identify: bool |
|
|
|
shard_id: Optional[int] |
|
|
|
shard_count: Optional[int] |
|
|
|
gateway: str |
|
|
|
_max_heartbeat_timeout: float |
|
|
|
|
|
|
|
# fmt: off |
|
|
|
DISPATCH = 0 |
|
|
|
HEARTBEAT = 1 |
|
|
@ -277,51 +307,51 @@ class DiscordWebSocket: |
|
|
|
GUILD_SYNC = 12 |
|
|
|
# fmt: on |
|
|
|
|
|
|
|
def __init__(self, socket, *, loop): |
|
|
|
self.socket = socket |
|
|
|
self.loop = loop |
|
|
|
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
self.socket: aiohttp.ClientWebSocketResponse = socket |
|
|
|
self.loop: asyncio.AbstractEventLoop = loop |
|
|
|
|
|
|
|
# an empty dispatcher to prevent crashes |
|
|
|
self._dispatch = lambda *args: None |
|
|
|
self._dispatch: Callable[..., Any] = lambda *args: None |
|
|
|
# generic event listeners |
|
|
|
self._dispatch_listeners = [] |
|
|
|
self._dispatch_listeners: List[EventListener] = [] |
|
|
|
# the keep alive |
|
|
|
self._keep_alive = None |
|
|
|
self.thread_id = threading.get_ident() |
|
|
|
self._keep_alive: Optional[KeepAliveHandler] = None |
|
|
|
self.thread_id: int = threading.get_ident() |
|
|
|
|
|
|
|
# ws related stuff |
|
|
|
self.session_id = None |
|
|
|
self.sequence = None |
|
|
|
self._zlib = zlib.decompressobj() |
|
|
|
self._buffer = bytearray() |
|
|
|
self._close_code = None |
|
|
|
self._rate_limiter = GatewayRatelimiter() |
|
|
|
self.session_id: Optional[str] = None |
|
|
|
self.sequence: Optional[int] = None |
|
|
|
self._zlib: zlib._Decompress = zlib.decompressobj() |
|
|
|
self._buffer: bytearray = bytearray() |
|
|
|
self._close_code: Optional[int] = None |
|
|
|
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() |
|
|
|
|
|
|
|
@property |
|
|
|
def open(self): |
|
|
|
def open(self) -> bool: |
|
|
|
return not self.socket.closed |
|
|
|
|
|
|
|
def is_ratelimited(self): |
|
|
|
def is_ratelimited(self) -> bool: |
|
|
|
return self._rate_limiter.is_ratelimited() |
|
|
|
|
|
|
|
def debug_log_receive(self, data, /): |
|
|
|
def debug_log_receive(self, data: Dict[str, Any], /) -> None: |
|
|
|
self._dispatch('socket_raw_receive', data) |
|
|
|
|
|
|
|
def log_receive(self, _, /): |
|
|
|
def log_receive(self, _: Dict[str, Any], /) -> None: |
|
|
|
pass |
|
|
|
|
|
|
|
@classmethod |
|
|
|
async def from_client( |
|
|
|
cls, |
|
|
|
client, |
|
|
|
cls: Type[DWS], |
|
|
|
client: Client, |
|
|
|
*, |
|
|
|
initial=False, |
|
|
|
gateway=None, |
|
|
|
shard_id=None, |
|
|
|
session=None, |
|
|
|
sequence=None, |
|
|
|
resume=False, |
|
|
|
): |
|
|
|
initial: bool = False, |
|
|
|
gateway: Optional[str] = None, |
|
|
|
shard_id: Optional[int] = None, |
|
|
|
session: Optional[str] = None, |
|
|
|
sequence: Optional[int] = None, |
|
|
|
resume: bool = False, |
|
|
|
) -> DWS: |
|
|
|
"""Creates a main websocket for Discord from a :class:`Client`. |
|
|
|
|
|
|
|
This is for internal use only. |
|
|
@ -363,7 +393,12 @@ class DiscordWebSocket: |
|
|
|
await ws.resume() |
|
|
|
return ws |
|
|
|
|
|
|
|
def wait_for(self, event, predicate, result=None): |
|
|
|
def wait_for( |
|
|
|
self, |
|
|
|
event: str, |
|
|
|
predicate: Callable[[Dict[str, Any]], bool], |
|
|
|
result: Optional[Callable[[Dict[str, Any]], Any]] = None, |
|
|
|
) -> asyncio.Future[Any]: |
|
|
|
"""Waits for a DISPATCH'd event that meets the predicate. |
|
|
|
|
|
|
|
Parameters |
|
|
@ -388,7 +423,7 @@ class DiscordWebSocket: |
|
|
|
self._dispatch_listeners.append(entry) |
|
|
|
return future |
|
|
|
|
|
|
|
async def identify(self): |
|
|
|
async def identify(self) -> None: |
|
|
|
"""Sends the IDENTIFY packet.""" |
|
|
|
payload = { |
|
|
|
'op': self.IDENTIFY, |
|
|
@ -426,7 +461,7 @@ class DiscordWebSocket: |
|
|
|
await self.send_as_json(payload) |
|
|
|
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) |
|
|
|
|
|
|
|
async def resume(self): |
|
|
|
async def resume(self) -> None: |
|
|
|
"""Sends the RESUME packet.""" |
|
|
|
payload = { |
|
|
|
'op': self.RESUME, |
|
|
@ -440,7 +475,7 @@ class DiscordWebSocket: |
|
|
|
await self.send_as_json(payload) |
|
|
|
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) |
|
|
|
|
|
|
|
async def received_message(self, msg, /): |
|
|
|
async def received_message(self, msg: Any, /) -> None: |
|
|
|
if type(msg) is bytes: |
|
|
|
self._buffer.extend(msg) |
|
|
|
|
|
|
@ -566,16 +601,16 @@ class DiscordWebSocket: |
|
|
|
del self._dispatch_listeners[index] |
|
|
|
|
|
|
|
@property |
|
|
|
def latency(self): |
|
|
|
def latency(self) -> float: |
|
|
|
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" |
|
|
|
heartbeat = self._keep_alive |
|
|
|
return float('inf') if heartbeat is None else heartbeat.latency |
|
|
|
|
|
|
|
def _can_handle_close(self): |
|
|
|
def _can_handle_close(self) -> bool: |
|
|
|
code = self._close_code or self.socket.close_code |
|
|
|
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) |
|
|
|
|
|
|
|
async def poll_event(self): |
|
|
|
async def poll_event(self) -> None: |
|
|
|
"""Polls for a DISPATCH event and handles the general gateway loop. |
|
|
|
|
|
|
|
Raises |
|
|
@ -613,23 +648,23 @@ class DiscordWebSocket: |
|
|
|
_log.info('Websocket closed with %s, cannot reconnect.', code) |
|
|
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None |
|
|
|
|
|
|
|
async def debug_send(self, data, /): |
|
|
|
async def debug_send(self, data: str, /) -> None: |
|
|
|
await self._rate_limiter.block() |
|
|
|
self._dispatch('socket_raw_send', data) |
|
|
|
await self.socket.send_str(data) |
|
|
|
|
|
|
|
async def send(self, data, /): |
|
|
|
async def send(self, data: str, /) -> None: |
|
|
|
await self._rate_limiter.block() |
|
|
|
await self.socket.send_str(data) |
|
|
|
|
|
|
|
async def send_as_json(self, data): |
|
|
|
async def send_as_json(self, data: Any) -> None: |
|
|
|
try: |
|
|
|
await self.send(utils._to_json(data)) |
|
|
|
except RuntimeError as exc: |
|
|
|
if not self._can_handle_close(): |
|
|
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc |
|
|
|
|
|
|
|
async def send_heartbeat(self, data): |
|
|
|
async def send_heartbeat(self, data: Any) -> None: |
|
|
|
# This bypasses the rate limit handling code since it has a higher priority |
|
|
|
try: |
|
|
|
await self.socket.send_str(utils._to_json(data)) |
|
|
@ -637,13 +672,19 @@ class DiscordWebSocket: |
|
|
|
if not self._can_handle_close(): |
|
|
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc |
|
|
|
|
|
|
|
async def change_presence(self, *, activity=None, status=None, since=0.0): |
|
|
|
async def change_presence( |
|
|
|
self, |
|
|
|
*, |
|
|
|
activity: Optional[BaseActivity] = None, |
|
|
|
status: Optional[str] = None, |
|
|
|
since: float = 0.0, |
|
|
|
) -> None: |
|
|
|
if activity is not None: |
|
|
|
if not isinstance(activity, BaseActivity): |
|
|
|
raise InvalidArgument('activity must derive from BaseActivity.') |
|
|
|
activity = [activity.to_dict()] |
|
|
|
activities = [activity.to_dict()] |
|
|
|
else: |
|
|
|
activity = [] |
|
|
|
activities = [] |
|
|
|
|
|
|
|
if status == 'idle': |
|
|
|
since = int(time.time() * 1000) |
|
|
@ -651,7 +692,7 @@ class DiscordWebSocket: |
|
|
|
payload = { |
|
|
|
'op': self.PRESENCE, |
|
|
|
'd': { |
|
|
|
'activities': activity, |
|
|
|
'activities': activities, |
|
|
|
'afk': False, |
|
|
|
'since': since, |
|
|
|
'status': status, |
|
|
@ -662,7 +703,16 @@ class DiscordWebSocket: |
|
|
|
_log.debug('Sending "%s" to change status', sent) |
|
|
|
await self.send(sent) |
|
|
|
|
|
|
|
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): |
|
|
|
async def request_chunks( |
|
|
|
self, |
|
|
|
guild_id: int, |
|
|
|
query: Optional[str] = None, |
|
|
|
*, |
|
|
|
limit: int, |
|
|
|
user_ids: Optional[List[int]] = None, |
|
|
|
presences: bool = False, |
|
|
|
nonce: Optional[str] = None, |
|
|
|
) -> None: |
|
|
|
payload = { |
|
|
|
'op': self.REQUEST_MEMBERS, |
|
|
|
'd': { |
|
|
@ -683,7 +733,13 @@ class DiscordWebSocket: |
|
|
|
|
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): |
|
|
|
async def voice_state( |
|
|
|
self, |
|
|
|
guild_id: int, |
|
|
|
channel_id: int, |
|
|
|
self_mute: bool = False, |
|
|
|
self_deaf: bool = False, |
|
|
|
) -> None: |
|
|
|
payload = { |
|
|
|
'op': self.VOICE_STATE, |
|
|
|
'd': { |
|
|
@ -697,7 +753,7 @@ class DiscordWebSocket: |
|
|
|
_log.debug('Updating our voice state to %s.', payload) |
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def close(self, code=4000): |
|
|
|
async def close(self, code: int = 4000) -> None: |
|
|
|
if self._keep_alive: |
|
|
|
self._keep_alive.stop() |
|
|
|
self._keep_alive = None |
|
|
@ -706,6 +762,9 @@ class DiscordWebSocket: |
|
|
|
await self.socket.close(code=code) |
|
|
|
|
|
|
|
|
|
|
|
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') |
|
|
|
|
|
|
|
|
|
|
|
class DiscordVoiceWebSocket: |
|
|
|
"""Implements the websocket protocol for handling voice connections. |
|
|
|
|
|
|
@ -737,6 +796,12 @@ class DiscordVoiceWebSocket: |
|
|
|
Receive only. Indicates a user has disconnected from voice. |
|
|
|
""" |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
thread_id: int |
|
|
|
_connection: VoiceClient |
|
|
|
gateway: str |
|
|
|
_max_heartbeat_timeout: float |
|
|
|
|
|
|
|
# fmt: off |
|
|
|
IDENTIFY = 0 |
|
|
|
SELECT_PROTOCOL = 1 |
|
|
@ -752,25 +817,31 @@ class DiscordVoiceWebSocket: |
|
|
|
CLIENT_DISCONNECT = 13 |
|
|
|
# fmt: on |
|
|
|
|
|
|
|
def __init__(self, socket, loop, *, hook=None): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
socket: aiohttp.ClientWebSocketResponse, |
|
|
|
loop: asyncio.AbstractEventLoop, |
|
|
|
*, |
|
|
|
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, |
|
|
|
) -> None: |
|
|
|
self.ws = socket |
|
|
|
self.loop = loop |
|
|
|
self._keep_alive = None |
|
|
|
self._close_code = None |
|
|
|
self.secret_key = None |
|
|
|
if hook: |
|
|
|
self._hook = hook |
|
|
|
self._hook = hook # type: ignore - type-checker doesn't like overriding methods |
|
|
|
|
|
|
|
async def _hook(self, *args): |
|
|
|
async def _hook(self, *args: Any) -> None: |
|
|
|
pass |
|
|
|
|
|
|
|
async def send_as_json(self, data): |
|
|
|
async def send_as_json(self, data: Any) -> None: |
|
|
|
_log.debug('Sending voice websocket frame: %s.', data) |
|
|
|
await self.ws.send_str(utils._to_json(data)) |
|
|
|
|
|
|
|
send_heartbeat = send_as_json |
|
|
|
|
|
|
|
async def resume(self): |
|
|
|
async def resume(self) -> None: |
|
|
|
state = self._connection |
|
|
|
payload = { |
|
|
|
'op': self.RESUME, |
|
|
@ -782,7 +853,7 @@ class DiscordVoiceWebSocket: |
|
|
|
} |
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def identify(self): |
|
|
|
async def identify(self) -> None: |
|
|
|
state = self._connection |
|
|
|
payload = { |
|
|
|
'op': self.IDENTIFY, |
|
|
@ -796,7 +867,7 @@ class DiscordVoiceWebSocket: |
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
async def from_client(cls, client, *, resume=False, hook=None): |
|
|
|
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS: |
|
|
|
"""Creates a voice websocket for the :class:`VoiceClient`.""" |
|
|
|
gateway = 'wss://' + client.endpoint + '/?v=4' |
|
|
|
http = client._state.http |
|
|
@ -814,7 +885,7 @@ class DiscordVoiceWebSocket: |
|
|
|
|
|
|
|
return ws |
|
|
|
|
|
|
|
async def select_protocol(self, ip, port, mode): |
|
|
|
async def select_protocol(self, ip: str, port: int, mode: int) -> None: |
|
|
|
payload = { |
|
|
|
'op': self.SELECT_PROTOCOL, |
|
|
|
'd': { |
|
|
@ -829,7 +900,7 @@ class DiscordVoiceWebSocket: |
|
|
|
|
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def client_connect(self): |
|
|
|
async def client_connect(self) -> None: |
|
|
|
payload = { |
|
|
|
'op': self.CLIENT_CONNECT, |
|
|
|
'd': { |
|
|
@ -839,7 +910,7 @@ class DiscordVoiceWebSocket: |
|
|
|
|
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def speak(self, state=SpeakingState.voice): |
|
|
|
async def speak(self, state: SpeakingState = SpeakingState.voice) -> None: |
|
|
|
payload = { |
|
|
|
'op': self.SPEAKING, |
|
|
|
'd': { |
|
|
@ -850,28 +921,29 @@ class DiscordVoiceWebSocket: |
|
|
|
|
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def received_message(self, msg): |
|
|
|
async def received_message(self, msg: Dict[str, Any]) -> None: |
|
|
|
_log.debug('Voice websocket frame received: %s', msg) |
|
|
|
op = msg['op'] |
|
|
|
data = msg.get('d') |
|
|
|
|
|
|
|
if op == self.READY: |
|
|
|
await self.initial_connection(data) |
|
|
|
await self.initial_connection(data) # type: ignore - type-checker thinks data could be None |
|
|
|
elif op == self.HEARTBEAT_ACK: |
|
|
|
self._keep_alive.ack() |
|
|
|
self._keep_alive.ack() # type: ignore - _keep_alive can't be None at this point |
|
|
|
elif op == self.RESUMED: |
|
|
|
_log.info('Voice RESUME succeeded.') |
|
|
|
elif op == self.SESSION_DESCRIPTION: |
|
|
|
self._connection.mode = data['mode'] |
|
|
|
await self.load_secret_key(data) |
|
|
|
# type-checker thinks data could be None |
|
|
|
self._connection.mode = data['mode'] # type: ignore |
|
|
|
await self.load_secret_key(data) # type: ignore |
|
|
|
elif op == self.HELLO: |
|
|
|
interval = data['heartbeat_interval'] / 1000.0 |
|
|
|
interval = data['heartbeat_interval'] / 1000.0 # type: ignore - type-checker thinks data could be None |
|
|
|
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) |
|
|
|
self._keep_alive.start() |
|
|
|
|
|
|
|
await self._hook(self, msg) |
|
|
|
|
|
|
|
async def initial_connection(self, data): |
|
|
|
async def initial_connection(self, data: Dict[str, Any]) -> None: |
|
|
|
state = self._connection |
|
|
|
state.ssrc = data['ssrc'] |
|
|
|
state.voice_port = data['port'] |
|
|
@ -888,41 +960,41 @@ class DiscordVoiceWebSocket: |
|
|
|
# the ip is ascii starting at the 4th byte and ending at the first null |
|
|
|
ip_start = 4 |
|
|
|
ip_end = recv.index(0, ip_start) |
|
|
|
state.ip = recv[ip_start:ip_end].decode('ascii') |
|
|
|
state.endpoint_ip = recv[ip_start:ip_end].decode('ascii') |
|
|
|
|
|
|
|
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] |
|
|
|
_log.debug('detected ip: %s port: %s', state.ip, state.port) |
|
|
|
state.voice_port = struct.unpack_from('>H', recv, len(recv) - 2)[0] |
|
|
|
_log.debug('detected ip: %s port: %s', state.endpoint_ip, state.voice_port) |
|
|
|
|
|
|
|
# there *should* always be at least one supported mode (xsalsa20_poly1305) |
|
|
|
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] |
|
|
|
_log.debug('received supported encryption modes: %s', ", ".join(modes)) |
|
|
|
|
|
|
|
mode = modes[0] |
|
|
|
await self.select_protocol(state.ip, state.port, mode) |
|
|
|
await self.select_protocol(state.endpoint_ip, state.voice_port, mode) |
|
|
|
_log.info('selected the voice protocol for use (%s)', mode) |
|
|
|
|
|
|
|
@property |
|
|
|
def latency(self): |
|
|
|
def latency(self) -> float: |
|
|
|
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" |
|
|
|
heartbeat = self._keep_alive |
|
|
|
return float('inf') if heartbeat is None else heartbeat.latency |
|
|
|
|
|
|
|
@property |
|
|
|
def average_latency(self): |
|
|
|
""":class:`list`: Average of last 20 HEARTBEAT latencies.""" |
|
|
|
def average_latency(self) -> float: |
|
|
|
""":class:`float`: Average of last 20 HEARTBEAT latencies.""" |
|
|
|
heartbeat = self._keep_alive |
|
|
|
if heartbeat is None or not heartbeat.recent_ack_latencies: |
|
|
|
return float('inf') |
|
|
|
|
|
|
|
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) |
|
|
|
|
|
|
|
async def load_secret_key(self, data): |
|
|
|
async def load_secret_key(self, data: Dict[str, Any]) -> None: |
|
|
|
_log.info('received secret key for voice connection') |
|
|
|
self.secret_key = self._connection.secret_key = data.get('secret_key') |
|
|
|
self.secret_key = self._connection.secret_key = data.get('secret_key') # type: ignore - type-checker thinks secret_key could be None |
|
|
|
await self.speak() |
|
|
|
await self.speak(False) |
|
|
|
await self.speak(SpeakingState.none) |
|
|
|
|
|
|
|
async def poll_event(self): |
|
|
|
async def poll_event(self) -> None: |
|
|
|
# This exception is handled up the chain |
|
|
|
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) |
|
|
|
if msg.type is aiohttp.WSMsgType.TEXT: |
|
|
@ -934,7 +1006,7 @@ class DiscordVoiceWebSocket: |
|
|
|
_log.debug('Received %s', msg) |
|
|
|
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) |
|
|
|
|
|
|
|
async def close(self, code=1000): |
|
|
|
async def close(self, code: int = 1000) -> None: |
|
|
|
if self._keep_alive is not None: |
|
|
|
self._keep_alive.stop() |
|
|
|
|
|
|
|