From efab542aba3251d4f16e2c4f63a09b04c34c1fec Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 17 Sep 2022 22:49:29 -0400 Subject: [PATCH] Parse gateway URL as an actual URL using yarl Discord has changed the URL format to make it infeasible to edit it using basic string interpolation. --- discord/gateway.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index 69892541e..2dd6facbf 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -35,6 +35,7 @@ import zlib from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar import aiohttp +import yarl from . import utils from .activity import BaseActivity, Spotify @@ -286,14 +287,14 @@ class DiscordWebSocket: _initial_identify: bool shard_id: Optional[int] shard_count: Optional[int] - gateway: str + gateway: yarl.URL _max_heartbeat_timeout: float _user_agent: str _super_properties: Dict[str, Any] _zlib_enabled: bool # fmt: off - DEFAULT_GATEWAY = 'wss://gateway.discord.gg/' + DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/') DISPATCH = 0 HEARTBEAT = 1 IDENTIFY = 2 @@ -355,7 +356,7 @@ class DiscordWebSocket: client: Client, *, initial: bool = False, - gateway: Optional[str] = None, + gateway: Optional[yarl.URL] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False, @@ -372,11 +373,11 @@ class DiscordWebSocket: gateway = gateway or cls.DEFAULT_GATEWAY if zlib: - url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}&compress=zlib-stream' + url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') else: - url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}' + url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) - socket = await client.http.ws_connect(url) + socket = await client.http.ws_connect(str(url)) ws = cls(socket, loop=client.loop) # Dynamically add attributes needed @@ -577,7 +578,7 @@ class DiscordWebSocket: self._trace = trace = data.get('_trace', []) self.sequence = msg['s'] self.session_id = data['session_id'] - self.gateway = data['resume_gateway_url'] + self.gateway = yarl.URL(data['resume_gateway_url']) _log.info('Connected to Gateway: %s (Session ID: %s).', ', '.join(trace), self.session_id) elif event == 'RESUMED':