Browse Source

Type-hint gateway

pull/7494/head
Josh 3 years ago
committed by Rapptz
parent
commit
c8064ba6f2
  1. 308
      discord/gateway.py
  2. 2
      discord/http.py
  3. 1
      discord/voice_client.py

308
discord/gateway.py

@ -21,9 +21,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
from collections import namedtuple, deque from collections import deque
import concurrent.futures import concurrent.futures
import logging import logging
import struct import struct
@ -33,6 +34,8 @@ import threading
import traceback import traceback
import zlib import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Type
import aiohttp import aiohttp
from . import utils from . import utils
@ -50,6 +53,11 @@ __all__ = (
'ReconnectWebSocket', 'ReconnectWebSocket',
) )
if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
class ReconnectWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket.""" """Signals to safely reconnect the websocket."""
@ -66,26 +74,30 @@ class WebSocketClosure(Exception):
pass pass
EventListener = namedtuple('EventListener', 'predicate event result future') class EventListener(NamedTuple):
predicate: Callable[[Dict[str, Any]], bool]
event: str
result: Optional[Callable[[Dict[str, Any]], Any]]
future: asyncio.Future[Any]
class GatewayRatelimiter: class GatewayRatelimiter:
def __init__(self, count=110, per=60.0): def __init__(self, count: int = 110, per: float = 60.0) -> None:
# The default is 110 to give room for at least 10 heartbeats per minute # The default is 110 to give room for at least 10 heartbeats per minute
self.max = count self.max: int = count
self.remaining = count self.remaining: int = count
self.window = 0.0 self.window: float = 0.0
self.per = per self.per: float = per
self.lock = asyncio.Lock() self.lock: asyncio.Lock = asyncio.Lock()
self.shard_id = None self.shard_id: Optional[int] = None
def is_ratelimited(self): def is_ratelimited(self) -> bool:
current = time.time() current = time.time()
if current > self.window + self.per: if current > self.window + self.per:
return False return False
return self.remaining == 0 return self.remaining == 0
def get_delay(self): def get_delay(self) -> float:
current = time.time() current = time.time()
if current > self.window + self.per: if current > self.window + self.per:
@ -103,7 +115,7 @@ class GatewayRatelimiter:
return 0.0 return 0.0
async def block(self): async def block(self) -> None:
async with self.lock: async with self.lock:
delta = self.get_delay() delta = self.get_delay()
if delta: if delta:
@ -112,27 +124,31 @@ class GatewayRatelimiter:
class KeepAliveHandler(threading.Thread): class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs): def __init__(
ws = kwargs.pop('ws', None) self,
interval = kwargs.pop('interval', None) *args: Any,
shard_id = kwargs.pop('shard_id', None) ws: DiscordWebSocket,
threading.Thread.__init__(self, *args, **kwargs) interval: Optional[float] = None,
self.ws = ws shard_id: Optional[int] = None,
self._main_thread_id = ws.thread_id **kwargs: Any,
self.interval = interval ) -> None:
self.daemon = True super().__init__(*args, **kwargs)
self.shard_id = shard_id self.ws: DiscordWebSocket = ws
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' self._main_thread_id: int = ws.thread_id
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' self.interval: Optional[float] = interval
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' self.daemon: bool = True
self._stop_ev = threading.Event() self.shard_id: Optional[int] = shard_id
self._last_ack = time.perf_counter() self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
self._last_send = time.perf_counter() self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self._last_recv = time.perf_counter() self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self.latency = float('inf') self._stop_ev: threading.Event = threading.Event()
self.heartbeat_timeout = ws._max_heartbeat_timeout self._last_ack: float = time.perf_counter()
self._last_send: float = time.perf_counter()
def run(self): self._last_recv: float = time.perf_counter()
self.latency: float = float('inf')
self.heartbeat_timeout: float = ws._max_heartbeat_timeout
def run(self) -> None:
while not self._stop_ev.wait(self.interval): while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter(): if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
@ -174,19 +190,19 @@ class KeepAliveHandler(threading.Thread):
else: else:
self._last_send = time.perf_counter() self._last_send = time.perf_counter()
def get_payload(self): def get_payload(self) -> Dict[str, Any]:
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': self.ws.sequence, 'd': self.ws.sequence,
} }
def stop(self): def stop(self) -> None:
self._stop_ev.set() self._stop_ev.set()
def tick(self): def tick(self) -> None:
self._last_recv = time.perf_counter() self._last_recv = time.perf_counter()
def ack(self): def ack(self) -> None:
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
@ -195,20 +211,20 @@ class KeepAliveHandler(threading.Thread):
class VoiceKeepAliveHandler(KeepAliveHandler): class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.recent_ack_latencies = deque(maxlen=20) self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' self.behind_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
def get_payload(self): def get_payload(self) -> Dict[str, Any]:
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000), 'd': int(time.time() * 1000),
} }
def ack(self): def ack(self) -> None:
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self._last_recv = ack_time self._last_recv = ack_time
@ -221,6 +237,9 @@ class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
return await super().close(code=code, message=message) return await super().close(code=code, message=message)
DWS = TypeVar('DWS', bound='DiscordWebSocket')
class DiscordWebSocket: class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6. """Implements a WebSocket for Discord's gateway v6.
@ -261,6 +280,17 @@ class DiscordWebSocket:
The authentication token for discord. The authentication token for discord.
""" """
if TYPE_CHECKING:
token: Optional[str]
_connection: ConnectionState
_discord_parsers: Dict[str, Callable[..., Any]]
call_hooks: Callable[..., Any]
_initial_identify: bool
shard_id: Optional[int]
shard_count: Optional[int]
gateway: str
_max_heartbeat_timeout: float
# fmt: off # fmt: off
DISPATCH = 0 DISPATCH = 0
HEARTBEAT = 1 HEARTBEAT = 1
@ -277,51 +307,51 @@ class DiscordWebSocket:
GUILD_SYNC = 12 GUILD_SYNC = 12
# fmt: on # fmt: on
def __init__(self, socket, *, loop): def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket = socket self.socket: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
# an empty dispatcher to prevent crashes # an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None self._dispatch: Callable[..., Any] = lambda *args: None
# generic event listeners # generic event listeners
self._dispatch_listeners = [] self._dispatch_listeners: List[EventListener] = []
# the keep alive # the keep alive
self._keep_alive = None self._keep_alive: Optional[KeepAliveHandler] = None
self.thread_id = threading.get_ident() self.thread_id: int = threading.get_ident()
# ws related stuff # ws related stuff
self.session_id = None self.session_id: Optional[str] = None
self.sequence = None self.sequence: Optional[int] = None
self._zlib = zlib.decompressobj() self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer = bytearray() self._buffer: bytearray = bytearray()
self._close_code = None self._close_code: Optional[int] = None
self._rate_limiter = GatewayRatelimiter() self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@property @property
def open(self): def open(self) -> bool:
return not self.socket.closed return not self.socket.closed
def is_ratelimited(self): def is_ratelimited(self) -> bool:
return self._rate_limiter.is_ratelimited() return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /): def debug_log_receive(self, data: Dict[str, Any], /) -> None:
self._dispatch('socket_raw_receive', data) self._dispatch('socket_raw_receive', data)
def log_receive(self, _, /): def log_receive(self, _: Dict[str, Any], /) -> None:
pass pass
@classmethod @classmethod
async def from_client( async def from_client(
cls, cls: Type[DWS],
client, client: Client,
*, *,
initial=False, initial: bool = False,
gateway=None, gateway: Optional[str] = None,
shard_id=None, shard_id: Optional[int] = None,
session=None, session: Optional[str] = None,
sequence=None, sequence: Optional[int] = None,
resume=False, resume: bool = False,
): ) -> DWS:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
@ -363,7 +393,12 @@ class DiscordWebSocket:
await ws.resume() await ws.resume()
return ws return ws
def wait_for(self, event, predicate, result=None): def wait_for(
self,
event: str,
predicate: Callable[[Dict[str, Any]], bool],
result: Optional[Callable[[Dict[str, Any]], Any]] = None,
) -> asyncio.Future[Any]:
"""Waits for a DISPATCH'd event that meets the predicate. """Waits for a DISPATCH'd event that meets the predicate.
Parameters Parameters
@ -388,7 +423,7 @@ class DiscordWebSocket:
self._dispatch_listeners.append(entry) self._dispatch_listeners.append(entry)
return future return future
async def identify(self): async def identify(self) -> None:
"""Sends the IDENTIFY packet.""" """Sends the IDENTIFY packet."""
payload = { payload = {
'op': self.IDENTIFY, 'op': self.IDENTIFY,
@ -426,7 +461,7 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
async def resume(self): async def resume(self) -> None:
"""Sends the RESUME packet.""" """Sends the RESUME packet."""
payload = { payload = {
'op': self.RESUME, 'op': self.RESUME,
@ -440,7 +475,7 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg, /): async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) self._buffer.extend(msg)
@ -566,16 +601,16 @@ class DiscordWebSocket:
del self._dispatch_listeners[index] del self._dispatch_listeners[index]
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
def _can_handle_close(self): def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self): async def poll_event(self) -> None:
"""Polls for a DISPATCH event and handles the general gateway loop. """Polls for a DISPATCH event and handles the general gateway loop.
Raises Raises
@ -613,23 +648,23 @@ class DiscordWebSocket:
_log.info('Websocket closed with %s, cannot reconnect.', code) _log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def debug_send(self, data, /): async def debug_send(self, data: str, /) -> None:
await self._rate_limiter.block() await self._rate_limiter.block()
self._dispatch('socket_raw_send', data) self._dispatch('socket_raw_send', data)
await self.socket.send_str(data) await self.socket.send_str(data)
async def send(self, data, /): async def send(self, data: str, /) -> None:
await self._rate_limiter.block() await self._rate_limiter.block()
await self.socket.send_str(data) await self.socket.send_str(data)
async def send_as_json(self, data): async def send_as_json(self, data: Any) -> None:
try: try:
await self.send(utils._to_json(data)) await self.send(utils._to_json(data))
except RuntimeError as exc: except RuntimeError as exc:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def send_heartbeat(self, data): async def send_heartbeat(self, data: Any) -> None:
# This bypasses the rate limit handling code since it has a higher priority # This bypasses the rate limit handling code since it has a higher priority
try: try:
await self.socket.send_str(utils._to_json(data)) await self.socket.send_str(utils._to_json(data))
@ -637,13 +672,19 @@ class DiscordWebSocket:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, since=0.0): async def change_presence(
self,
*,
activity: Optional[BaseActivity] = None,
status: Optional[str] = None,
since: float = 0.0,
) -> None:
if activity is not None: if activity is not None:
if not isinstance(activity, BaseActivity): if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.') raise InvalidArgument('activity must derive from BaseActivity.')
activity = [activity.to_dict()] activities = [activity.to_dict()]
else: else:
activity = [] activities = []
if status == 'idle': if status == 'idle':
since = int(time.time() * 1000) since = int(time.time() * 1000)
@ -651,7 +692,7 @@ class DiscordWebSocket:
payload = { payload = {
'op': self.PRESENCE, 'op': self.PRESENCE,
'd': { 'd': {
'activities': activity, 'activities': activities,
'afk': False, 'afk': False,
'since': since, 'since': since,
'status': status, 'status': status,
@ -662,7 +703,16 @@ class DiscordWebSocket:
_log.debug('Sending "%s" to change status', sent) _log.debug('Sending "%s" to change status', sent)
await self.send(sent) await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): async def request_chunks(
self,
guild_id: int,
query: Optional[str] = None,
*,
limit: int,
user_ids: Optional[List[int]] = None,
presences: bool = False,
nonce: Optional[str] = None,
) -> None:
payload = { payload = {
'op': self.REQUEST_MEMBERS, 'op': self.REQUEST_MEMBERS,
'd': { 'd': {
@ -683,7 +733,13 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): async def voice_state(
self,
guild_id: int,
channel_id: int,
self_mute: bool = False,
self_deaf: bool = False,
) -> None:
payload = { payload = {
'op': self.VOICE_STATE, 'op': self.VOICE_STATE,
'd': { 'd': {
@ -697,7 +753,7 @@ class DiscordWebSocket:
_log.debug('Updating our voice state to %s.', payload) _log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload) await self.send_as_json(payload)
async def close(self, code=4000): async def close(self, code: int = 4000) -> None:
if self._keep_alive: if self._keep_alive:
self._keep_alive.stop() self._keep_alive.stop()
self._keep_alive = None self._keep_alive = None
@ -706,6 +762,9 @@ class DiscordWebSocket:
await self.socket.close(code=code) await self.socket.close(code=code)
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
class DiscordVoiceWebSocket: class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections. """Implements the websocket protocol for handling voice connections.
@ -737,6 +796,12 @@ class DiscordVoiceWebSocket:
Receive only. Indicates a user has disconnected from voice. Receive only. Indicates a user has disconnected from voice.
""" """
if TYPE_CHECKING:
thread_id: int
_connection: VoiceClient
gateway: str
_max_heartbeat_timeout: float
# fmt: off # fmt: off
IDENTIFY = 0 IDENTIFY = 0
SELECT_PROTOCOL = 1 SELECT_PROTOCOL = 1
@ -752,25 +817,31 @@ class DiscordVoiceWebSocket:
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
# fmt: on # fmt: on
def __init__(self, socket, loop, *, hook=None): def __init__(
self,
socket: aiohttp.ClientWebSocketResponse,
loop: asyncio.AbstractEventLoop,
*,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> None:
self.ws = socket self.ws = socket
self.loop = loop self.loop = loop
self._keep_alive = None self._keep_alive = None
self._close_code = None self._close_code = None
self.secret_key = None self.secret_key = None
if hook: if hook:
self._hook = hook self._hook = hook # type: ignore - type-checker doesn't like overriding methods
async def _hook(self, *args): async def _hook(self, *args: Any) -> None:
pass pass
async def send_as_json(self, data): async def send_as_json(self, data: Any) -> None:
_log.debug('Sending voice websocket frame: %s.', data) _log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data)) await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json send_heartbeat = send_as_json
async def resume(self): async def resume(self) -> None:
state = self._connection state = self._connection
payload = { payload = {
'op': self.RESUME, 'op': self.RESUME,
@ -782,7 +853,7 @@ class DiscordVoiceWebSocket:
} }
await self.send_as_json(payload) await self.send_as_json(payload)
async def identify(self): async def identify(self) -> None:
state = self._connection state = self._connection
payload = { payload = {
'op': self.IDENTIFY, 'op': self.IDENTIFY,
@ -796,7 +867,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
@classmethod @classmethod
async def from_client(cls, client, *, resume=False, hook=None): async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS:
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http http = client._state.http
@ -814,7 +885,7 @@ class DiscordVoiceWebSocket:
return ws return ws
async def select_protocol(self, ip, port, mode): async def select_protocol(self, ip: str, port: int, mode: int) -> None:
payload = { payload = {
'op': self.SELECT_PROTOCOL, 'op': self.SELECT_PROTOCOL,
'd': { 'd': {
@ -829,7 +900,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def client_connect(self): async def client_connect(self) -> None:
payload = { payload = {
'op': self.CLIENT_CONNECT, 'op': self.CLIENT_CONNECT,
'd': { 'd': {
@ -839,7 +910,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def speak(self, state=SpeakingState.voice): async def speak(self, state: SpeakingState = SpeakingState.voice) -> None:
payload = { payload = {
'op': self.SPEAKING, 'op': self.SPEAKING,
'd': { 'd': {
@ -850,28 +921,29 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def received_message(self, msg): async def received_message(self, msg: Dict[str, Any]) -> None:
_log.debug('Voice websocket frame received: %s', msg) _log.debug('Voice websocket frame received: %s', msg)
op = msg['op'] op = msg['op']
data = msg.get('d') data = msg.get('d')
if op == self.READY: if op == self.READY:
await self.initial_connection(data) await self.initial_connection(data) # type: ignore - type-checker thinks data could be None
elif op == self.HEARTBEAT_ACK: elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack() self._keep_alive.ack() # type: ignore - _keep_alive can't be None at this point
elif op == self.RESUMED: elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.') _log.info('Voice RESUME succeeded.')
elif op == self.SESSION_DESCRIPTION: elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode'] # type-checker thinks data could be None
await self.load_secret_key(data) self._connection.mode = data['mode'] # type: ignore
await self.load_secret_key(data) # type: ignore
elif op == self.HELLO: elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0 interval = data['heartbeat_interval'] / 1000.0 # type: ignore - type-checker thinks data could be None
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start() self._keep_alive.start()
await self._hook(self, msg) await self._hook(self, msg)
async def initial_connection(self, data): async def initial_connection(self, data: Dict[str, Any]) -> None:
state = self._connection state = self._connection
state.ssrc = data['ssrc'] state.ssrc = data['ssrc']
state.voice_port = data['port'] state.voice_port = data['port']
@ -888,41 +960,41 @@ class DiscordVoiceWebSocket:
# the ip is ascii starting at the 4th byte and ending at the first null # the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4 ip_start = 4
ip_end = recv.index(0, ip_start) ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii') state.endpoint_ip = recv[ip_start:ip_end].decode('ascii')
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] state.voice_port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port) _log.debug('detected ip: %s port: %s', state.endpoint_ip, state.voice_port)
# there *should* always be at least one supported mode (xsalsa20_poly1305) # there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes)) _log.debug('received supported encryption modes: %s', ", ".join(modes))
mode = modes[0] mode = modes[0]
await self.select_protocol(state.ip, state.port, mode) await self.select_protocol(state.endpoint_ip, state.voice_port, mode)
_log.info('selected the voice protocol for use (%s)', mode) _log.info('selected the voice protocol for use (%s)', mode)
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
@property @property
def average_latency(self): def average_latency(self) -> float:
""":class:`list`: Average of last 20 HEARTBEAT latencies.""" """:class:`float`: Average of last 20 HEARTBEAT latencies."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies: if heartbeat is None or not heartbeat.recent_ack_latencies:
return float('inf') return float('inf')
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data): async def load_secret_key(self, data: Dict[str, Any]) -> None:
_log.info('received secret key for voice connection') _log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key') self.secret_key = self._connection.secret_key = data.get('secret_key') # type: ignore - type-checker thinks secret_key could be None
await self.speak() await self.speak()
await self.speak(False) await self.speak(SpeakingState.none)
async def poll_event(self): async def poll_event(self) -> None:
# This exception is handled up the chain # This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT: if msg.type is aiohttp.WSMsgType.TEXT:
@ -934,7 +1006,7 @@ class DiscordVoiceWebSocket:
_log.debug('Received %s', msg) _log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000): async def close(self, code: int = 1000) -> None:
if self._keep_alive is not None: if self._keep_alive is not None:
self._keep_alive.stop() self._keep_alive.stop()

2
discord/http.py

@ -341,7 +341,7 @@ class HTTPClient:
connector=self.connector, ws_response_class=DiscordClientWebSocketResponse connector=self.connector, ws_response_class=DiscordClientWebSocketResponse
) )
async def ws_connect(self, url: str, *, compress: int = 0) -> Any: async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse:
kwargs = { kwargs = {
'proxy_auth': self.proxy_auth, 'proxy_auth': self.proxy_auth,
'proxy': self.proxy, 'proxy': self.proxy,

1
discord/voice_client.py

@ -239,6 +239,7 @@ class VoiceClient(VoiceProtocol):
super().__init__(client, channel) super().__init__(client, channel)
state = client._connection state = client._connection
self.token: str = MISSING self.token: str = MISSING
self.server_id: int = MISSING
self.socket = MISSING self.socket = MISSING
self.loop: asyncio.AbstractEventLoop = state.loop self.loop: asyncio.AbstractEventLoop = state.loop
self._state: ConnectionState = state self._state: ConnectionState = state

Loading…
Cancel
Save