From c8064ba6f2998d03a0778d89bf12cdfa548a32e4 Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 22 Feb 2022 16:44:52 +1000 Subject: [PATCH] Type-hint gateway --- discord/gateway.py | 308 +++++++++++++++++++++++++--------------- discord/http.py | 2 +- discord/voice_client.py | 1 + 3 files changed, 192 insertions(+), 119 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index a8f5c0dd0..1f2030f1b 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -21,9 +21,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations import asyncio -from collections import namedtuple, deque +from collections import deque import concurrent.futures import logging import struct @@ -33,6 +34,8 @@ import threading import traceback import zlib +from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Type + import aiohttp from . import utils @@ -50,6 +53,11 @@ __all__ = ( 'ReconnectWebSocket', ) +if TYPE_CHECKING: + from .client import Client + from .state import ConnectionState + from .voice_client import VoiceClient + class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" @@ -66,26 +74,30 @@ class WebSocketClosure(Exception): pass -EventListener = namedtuple('EventListener', 'predicate event result future') +class EventListener(NamedTuple): + predicate: Callable[[Dict[str, Any]], bool] + event: str + result: Optional[Callable[[Dict[str, Any]], Any]] + future: asyncio.Future[Any] class GatewayRatelimiter: - def __init__(self, count=110, per=60.0): + def __init__(self, count: int = 110, per: float = 60.0) -> None: # The default is 110 to give room for at least 10 heartbeats per minute - self.max = count - self.remaining = count - self.window = 0.0 - self.per = per - self.lock = asyncio.Lock() - self.shard_id = None - - def is_ratelimited(self): + self.max: int = count + self.remaining: int = count + self.window: float = 0.0 + self.per: float = per + self.lock: asyncio.Lock = asyncio.Lock() + self.shard_id: Optional[int] = None + + def is_ratelimited(self) -> bool: current = time.time() if current > self.window + self.per: return False return self.remaining == 0 - def get_delay(self): + def get_delay(self) -> float: current = time.time() if current > self.window + self.per: @@ -103,7 +115,7 @@ class GatewayRatelimiter: return 0.0 - async def block(self): + async def block(self) -> None: async with self.lock: delta = self.get_delay() if delta: @@ -112,27 +124,31 @@ class GatewayRatelimiter: class KeepAliveHandler(threading.Thread): - def __init__(self, *args, **kwargs): - ws = kwargs.pop('ws', None) - interval = kwargs.pop('interval', None) - shard_id = kwargs.pop('shard_id', None) - threading.Thread.__init__(self, *args, **kwargs) - self.ws = ws - self._main_thread_id = ws.thread_id - self.interval = interval - self.daemon = True - self.shard_id = shard_id - self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' - self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' - self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' - self._stop_ev = threading.Event() - self._last_ack = time.perf_counter() - self._last_send = time.perf_counter() - self._last_recv = time.perf_counter() - self.latency = float('inf') - self.heartbeat_timeout = ws._max_heartbeat_timeout - - def run(self): + def __init__( + self, + *args: Any, + ws: DiscordWebSocket, + interval: Optional[float] = None, + shard_id: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.ws: DiscordWebSocket = ws + self._main_thread_id: int = ws.thread_id + self.interval: Optional[float] = interval + self.daemon: bool = True + self.shard_id: Optional[int] = shard_id + self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.' + self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.' + self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' + self._stop_ev: threading.Event = threading.Event() + self._last_ack: float = time.perf_counter() + self._last_send: float = time.perf_counter() + self._last_recv: float = time.perf_counter() + self.latency: float = float('inf') + self.heartbeat_timeout: float = ws._max_heartbeat_timeout + + def run(self) -> None: while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) @@ -174,19 +190,19 @@ class KeepAliveHandler(threading.Thread): else: self._last_send = time.perf_counter() - def get_payload(self): + def get_payload(self) -> Dict[str, Any]: return { 'op': self.ws.HEARTBEAT, 'd': self.ws.sequence, } - def stop(self): + def stop(self) -> None: self._stop_ev.set() - def tick(self): + def tick(self) -> None: self._last_recv = time.perf_counter() - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self.latency = ack_time - self._last_send @@ -195,20 +211,20 @@ class KeepAliveHandler(threading.Thread): class VoiceKeepAliveHandler(KeepAliveHandler): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.recent_ack_latencies = deque(maxlen=20) - self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' - self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' - self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' + self.recent_ack_latencies: Deque[float] = deque(maxlen=20) + self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.' + self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds' + self.behind_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind' - def get_payload(self): + def get_payload(self) -> Dict[str, Any]: return { 'op': self.ws.HEARTBEAT, 'd': int(time.time() * 1000), } - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self._last_recv = ack_time @@ -221,6 +237,9 @@ class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): return await super().close(code=code, message=message) +DWS = TypeVar('DWS', bound='DiscordWebSocket') + + class DiscordWebSocket: """Implements a WebSocket for Discord's gateway v6. @@ -261,6 +280,17 @@ class DiscordWebSocket: The authentication token for discord. """ + if TYPE_CHECKING: + token: Optional[str] + _connection: ConnectionState + _discord_parsers: Dict[str, Callable[..., Any]] + call_hooks: Callable[..., Any] + _initial_identify: bool + shard_id: Optional[int] + shard_count: Optional[int] + gateway: str + _max_heartbeat_timeout: float + # fmt: off DISPATCH = 0 HEARTBEAT = 1 @@ -277,51 +307,51 @@ class DiscordWebSocket: GUILD_SYNC = 12 # fmt: on - def __init__(self, socket, *, loop): - self.socket = socket - self.loop = loop + def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: + self.socket: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop # an empty dispatcher to prevent crashes - self._dispatch = lambda *args: None + self._dispatch: Callable[..., Any] = lambda *args: None # generic event listeners - self._dispatch_listeners = [] + self._dispatch_listeners: List[EventListener] = [] # the keep alive - self._keep_alive = None - self.thread_id = threading.get_ident() + self._keep_alive: Optional[KeepAliveHandler] = None + self.thread_id: int = threading.get_ident() # ws related stuff - self.session_id = None - self.sequence = None - self._zlib = zlib.decompressobj() - self._buffer = bytearray() - self._close_code = None - self._rate_limiter = GatewayRatelimiter() + self.session_id: Optional[str] = None + self.sequence: Optional[int] = None + self._zlib: zlib._Decompress = zlib.decompressobj() + self._buffer: bytearray = bytearray() + self._close_code: Optional[int] = None + self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() @property - def open(self): + def open(self) -> bool: return not self.socket.closed - def is_ratelimited(self): + def is_ratelimited(self) -> bool: return self._rate_limiter.is_ratelimited() - def debug_log_receive(self, data, /): + def debug_log_receive(self, data: Dict[str, Any], /) -> None: self._dispatch('socket_raw_receive', data) - def log_receive(self, _, /): + def log_receive(self, _: Dict[str, Any], /) -> None: pass @classmethod async def from_client( - cls, - client, + cls: Type[DWS], + client: Client, *, - initial=False, - gateway=None, - shard_id=None, - session=None, - sequence=None, - resume=False, - ): + initial: bool = False, + gateway: Optional[str] = None, + shard_id: Optional[int] = None, + session: Optional[str] = None, + sequence: Optional[int] = None, + resume: bool = False, + ) -> DWS: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -363,7 +393,12 @@ class DiscordWebSocket: await ws.resume() return ws - def wait_for(self, event, predicate, result=None): + def wait_for( + self, + event: str, + predicate: Callable[[Dict[str, Any]], bool], + result: Optional[Callable[[Dict[str, Any]], Any]] = None, + ) -> asyncio.Future[Any]: """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -388,7 +423,7 @@ class DiscordWebSocket: self._dispatch_listeners.append(entry) return future - async def identify(self): + async def identify(self) -> None: """Sends the IDENTIFY packet.""" payload = { 'op': self.IDENTIFY, @@ -426,7 +461,7 @@ class DiscordWebSocket: await self.send_as_json(payload) _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) - async def resume(self): + async def resume(self) -> None: """Sends the RESUME packet.""" payload = { 'op': self.RESUME, @@ -440,7 +475,7 @@ class DiscordWebSocket: await self.send_as_json(payload) _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) - async def received_message(self, msg, /): + async def received_message(self, msg: Any, /) -> None: if type(msg) is bytes: self._buffer.extend(msg) @@ -566,16 +601,16 @@ class DiscordWebSocket: del self._dispatch_listeners[index] @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency - def _can_handle_close(self): + def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) - async def poll_event(self): + async def poll_event(self) -> None: """Polls for a DISPATCH event and handles the general gateway loop. Raises @@ -613,23 +648,23 @@ class DiscordWebSocket: _log.info('Websocket closed with %s, cannot reconnect.', code) raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None - async def debug_send(self, data, /): + async def debug_send(self, data: str, /) -> None: await self._rate_limiter.block() self._dispatch('socket_raw_send', data) await self.socket.send_str(data) - async def send(self, data, /): + async def send(self, data: str, /) -> None: await self._rate_limiter.block() await self.socket.send_str(data) - async def send_as_json(self, data): + async def send_as_json(self, data: Any) -> None: try: await self.send(utils._to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def send_heartbeat(self, data): + async def send_heartbeat(self, data: Any) -> None: # This bypasses the rate limit handling code since it has a higher priority try: await self.socket.send_str(utils._to_json(data)) @@ -637,13 +672,19 @@ class DiscordWebSocket: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def change_presence(self, *, activity=None, status=None, since=0.0): + async def change_presence( + self, + *, + activity: Optional[BaseActivity] = None, + status: Optional[str] = None, + since: float = 0.0, + ) -> None: if activity is not None: if not isinstance(activity, BaseActivity): raise InvalidArgument('activity must derive from BaseActivity.') - activity = [activity.to_dict()] + activities = [activity.to_dict()] else: - activity = [] + activities = [] if status == 'idle': since = int(time.time() * 1000) @@ -651,7 +692,7 @@ class DiscordWebSocket: payload = { 'op': self.PRESENCE, 'd': { - 'activities': activity, + 'activities': activities, 'afk': False, 'since': since, 'status': status, @@ -662,7 +703,16 @@ class DiscordWebSocket: _log.debug('Sending "%s" to change status', sent) await self.send(sent) - async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): + async def request_chunks( + self, + guild_id: int, + query: Optional[str] = None, + *, + limit: int, + user_ids: Optional[List[int]] = None, + presences: bool = False, + nonce: Optional[str] = None, + ) -> None: payload = { 'op': self.REQUEST_MEMBERS, 'd': { @@ -683,7 +733,13 @@ class DiscordWebSocket: await self.send_as_json(payload) - async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + async def voice_state( + self, + guild_id: int, + channel_id: int, + self_mute: bool = False, + self_deaf: bool = False, + ) -> None: payload = { 'op': self.VOICE_STATE, 'd': { @@ -697,7 +753,7 @@ class DiscordWebSocket: _log.debug('Updating our voice state to %s.', payload) await self.send_as_json(payload) - async def close(self, code=4000): + async def close(self, code: int = 4000) -> None: if self._keep_alive: self._keep_alive.stop() self._keep_alive = None @@ -706,6 +762,9 @@ class DiscordWebSocket: await self.socket.close(code=code) +DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') + + class DiscordVoiceWebSocket: """Implements the websocket protocol for handling voice connections. @@ -737,6 +796,12 @@ class DiscordVoiceWebSocket: Receive only. Indicates a user has disconnected from voice. """ + if TYPE_CHECKING: + thread_id: int + _connection: VoiceClient + gateway: str + _max_heartbeat_timeout: float + # fmt: off IDENTIFY = 0 SELECT_PROTOCOL = 1 @@ -752,25 +817,31 @@ class DiscordVoiceWebSocket: CLIENT_DISCONNECT = 13 # fmt: on - def __init__(self, socket, loop, *, hook=None): + def __init__( + self, + socket: aiohttp.ClientWebSocketResponse, + loop: asyncio.AbstractEventLoop, + *, + hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, + ) -> None: self.ws = socket self.loop = loop self._keep_alive = None self._close_code = None self.secret_key = None if hook: - self._hook = hook + self._hook = hook # type: ignore - type-checker doesn't like overriding methods - async def _hook(self, *args): + async def _hook(self, *args: Any) -> None: pass - async def send_as_json(self, data): + async def send_as_json(self, data: Any) -> None: _log.debug('Sending voice websocket frame: %s.', data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json - async def resume(self): + async def resume(self) -> None: state = self._connection payload = { 'op': self.RESUME, @@ -782,7 +853,7 @@ class DiscordVoiceWebSocket: } await self.send_as_json(payload) - async def identify(self): + async def identify(self) -> None: state = self._connection payload = { 'op': self.IDENTIFY, @@ -796,7 +867,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) @classmethod - async def from_client(cls, client, *, resume=False, hook=None): + async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS: """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' http = client._state.http @@ -814,7 +885,7 @@ class DiscordVoiceWebSocket: return ws - async def select_protocol(self, ip, port, mode): + async def select_protocol(self, ip: str, port: int, mode: int) -> None: payload = { 'op': self.SELECT_PROTOCOL, 'd': { @@ -829,7 +900,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def client_connect(self): + async def client_connect(self) -> None: payload = { 'op': self.CLIENT_CONNECT, 'd': { @@ -839,7 +910,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def speak(self, state=SpeakingState.voice): + async def speak(self, state: SpeakingState = SpeakingState.voice) -> None: payload = { 'op': self.SPEAKING, 'd': { @@ -850,28 +921,29 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def received_message(self, msg): + async def received_message(self, msg: Dict[str, Any]) -> None: _log.debug('Voice websocket frame received: %s', msg) op = msg['op'] data = msg.get('d') if op == self.READY: - await self.initial_connection(data) + await self.initial_connection(data) # type: ignore - type-checker thinks data could be None elif op == self.HEARTBEAT_ACK: - self._keep_alive.ack() + self._keep_alive.ack() # type: ignore - _keep_alive can't be None at this point elif op == self.RESUMED: _log.info('Voice RESUME succeeded.') elif op == self.SESSION_DESCRIPTION: - self._connection.mode = data['mode'] - await self.load_secret_key(data) + # type-checker thinks data could be None + self._connection.mode = data['mode'] # type: ignore + await self.load_secret_key(data) # type: ignore elif op == self.HELLO: - interval = data['heartbeat_interval'] / 1000.0 + interval = data['heartbeat_interval'] / 1000.0 # type: ignore - type-checker thinks data could be None self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive.start() await self._hook(self, msg) - async def initial_connection(self, data): + async def initial_connection(self, data: Dict[str, Any]) -> None: state = self._connection state.ssrc = data['ssrc'] state.voice_port = data['port'] @@ -888,41 +960,41 @@ class DiscordVoiceWebSocket: # the ip is ascii starting at the 4th byte and ending at the first null ip_start = 4 ip_end = recv.index(0, ip_start) - state.ip = recv[ip_start:ip_end].decode('ascii') + state.endpoint_ip = recv[ip_start:ip_end].decode('ascii') - state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] - _log.debug('detected ip: %s port: %s', state.ip, state.port) + state.voice_port = struct.unpack_from('>H', recv, len(recv) - 2)[0] + _log.debug('detected ip: %s port: %s', state.endpoint_ip, state.voice_port) # there *should* always be at least one supported mode (xsalsa20_poly1305) modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] _log.debug('received supported encryption modes: %s', ", ".join(modes)) mode = modes[0] - await self.select_protocol(state.ip, state.port, mode) + await self.select_protocol(state.endpoint_ip, state.voice_port, mode) _log.info('selected the voice protocol for use (%s)', mode) @property - def latency(self): + def latency(self) -> float: """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency @property - def average_latency(self): - """:class:`list`: Average of last 20 HEARTBEAT latencies.""" + def average_latency(self) -> float: + """:class:`float`: Average of last 20 HEARTBEAT latencies.""" heartbeat = self._keep_alive if heartbeat is None or not heartbeat.recent_ack_latencies: return float('inf') return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) - async def load_secret_key(self, data): + async def load_secret_key(self, data: Dict[str, Any]) -> None: _log.info('received secret key for voice connection') - self.secret_key = self._connection.secret_key = data.get('secret_key') + self.secret_key = self._connection.secret_key = data.get('secret_key') # type: ignore - type-checker thinks secret_key could be None await self.speak() - await self.speak(False) + await self.speak(SpeakingState.none) - async def poll_event(self): + async def poll_event(self) -> None: # This exception is handled up the chain msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) if msg.type is aiohttp.WSMsgType.TEXT: @@ -934,7 +1006,7 @@ class DiscordVoiceWebSocket: _log.debug('Received %s', msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) - async def close(self, code=1000): + async def close(self, code: int = 1000) -> None: if self._keep_alive is not None: self._keep_alive.stop() diff --git a/discord/http.py b/discord/http.py index c5a9f5c9e..768b93713 100644 --- a/discord/http.py +++ b/discord/http.py @@ -341,7 +341,7 @@ class HTTPClient: connector=self.connector, ws_response_class=DiscordClientWebSocketResponse ) - async def ws_connect(self, url: str, *, compress: int = 0) -> Any: + async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse: kwargs = { 'proxy_auth': self.proxy_auth, 'proxy': self.proxy, diff --git a/discord/voice_client.py b/discord/voice_client.py index b29c4f1a9..13cdd9c03 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -239,6 +239,7 @@ class VoiceClient(VoiceProtocol): super().__init__(client, channel) state = client._connection self.token: str = MISSING + self.server_id: int = MISSING self.socket = MISSING self.loop: asyncio.AbstractEventLoop = state.loop self._state: ConnectionState = state