From 8aaeb6acfaac36a79181bf7dab82b309de99d9d5 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 17 Sep 2022 22:49:29 -0400 Subject: [PATCH] Parse gateway URL as an actual URL using yarl Discord has changed the URL format to make it infeasible to edit it using basic string interpolation. --- discord/gateway.py | 15 ++++++++------- discord/shard.py | 8 +++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index ee4c9942d..a06195307 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -37,6 +37,7 @@ import zlib from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar import aiohttp +import yarl from . import utils from .activity import BaseActivity @@ -287,11 +288,11 @@ class DiscordWebSocket: _initial_identify: bool shard_id: Optional[int] shard_count: Optional[int] - gateway: str + gateway: yarl.URL _max_heartbeat_timeout: float # fmt: off - DEFAULT_GATEWAY = 'wss://gateway.discord.gg/' + DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/') DISPATCH = 0 HEARTBEAT = 1 IDENTIFY = 2 @@ -346,7 +347,7 @@ class DiscordWebSocket: client: Client, *, initial: bool = False, - gateway: Optional[str] = None, + gateway: Optional[yarl.URL] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, @@ -364,11 +365,11 @@ class DiscordWebSocket: gateway = gateway or cls.DEFAULT_GATEWAY if zlib: - url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}&compress=zlib-stream' + url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') else: - url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}' + url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) - socket = await client.http.ws_connect(url) + socket = await client.http.ws_connect(str(url)) ws = cls(socket, loop=client.loop) # dynamically add attributes needed @@ -556,7 +557,7 @@ class DiscordWebSocket: if event == 'READY': self.sequence = msg['s'] self.session_id = data['session_id'] - self.gateway = data['resume_gateway_url'] + self.gateway = yarl.URL(data['resume_gateway_url']) _log.info('Shard ID %s has connected to Gateway (Session ID: %s).', self.shard_id, self.session_id) elif event == 'RESUMED': diff --git a/discord/shard.py b/discord/shard.py index cdb036027..cddb2d29f 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -28,6 +28,7 @@ import asyncio import logging import aiohttp +import yarl from .state import AutoShardedConnectionState from .client import Client @@ -403,7 +404,7 @@ class AutoShardedClient(Client): """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} - async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: + async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None: try: coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) ws = await asyncio.wait_for(coro, timeout=180.0) @@ -422,9 +423,10 @@ class AutoShardedClient(Client): if self.shard_count is None: self.shard_count: int - self.shard_count, gateway = await self.http.get_bot_gateway() + self.shard_count, gateway_url = await self.http.get_bot_gateway() + gateway = yarl.URL(gateway_url) else: - gateway = await self.http.get_gateway() + gateway = DiscordWebSocket.DEFAULT_GATEWAY self._connection.shard_count = self.shard_count