Browse Source

Parse gateway URL as an actual URL using yarl

Discord has changed the URL format to make it infeasible to edit it
using basic string interpolation.
pull/8476/head
Rapptz 3 years ago
parent
commit
8aaeb6acfa
  1. 15
      discord/gateway.py
  2. 8
      discord/shard.py

15
discord/gateway.py

@ -37,6 +37,7 @@ import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
import aiohttp
import yarl
from . import utils
from .activity import BaseActivity
@ -287,11 +288,11 @@ class DiscordWebSocket:
_initial_identify: bool
shard_id: Optional[int]
shard_count: Optional[int]
gateway: str
gateway: yarl.URL
_max_heartbeat_timeout: float
# fmt: off
DEFAULT_GATEWAY = 'wss://gateway.discord.gg/'
DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/')
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
@ -346,7 +347,7 @@ class DiscordWebSocket:
client: Client,
*,
initial: bool = False,
gateway: Optional[str] = None,
gateway: Optional[yarl.URL] = None,
shard_id: Optional[int] = None,
session: Optional[str] = None,
sequence: Optional[int] = None,
@ -364,11 +365,11 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY
if zlib:
url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}&compress=zlib-stream'
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
url = f'{gateway}?v={INTERNAL_API_VERSION}&encoding={encoding}'
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
socket = await client.http.ws_connect(url)
socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
# dynamically add attributes needed
@ -556,7 +557,7 @@ class DiscordWebSocket:
if event == 'READY':
self.sequence = msg['s']
self.session_id = data['session_id']
self.gateway = data['resume_gateway_url']
self.gateway = yarl.URL(data['resume_gateway_url'])
_log.info('Shard ID %s has connected to Gateway (Session ID: %s).', self.shard_id, self.session_id)
elif event == 'RESUMED':

8
discord/shard.py

@ -28,6 +28,7 @@ import asyncio
import logging
import aiohttp
import yarl
from .state import AutoShardedConnectionState
from .client import Client
@ -403,7 +404,7 @@ class AutoShardedClient(Client):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
@ -422,9 +423,10 @@ class AutoShardedClient(Client):
if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway = await self.http.get_bot_gateway()
self.shard_count, gateway_url = await self.http.get_bot_gateway()
gateway = yarl.URL(gateway_url)
else:
gateway = await self.http.get_gateway()
gateway = DiscordWebSocket.DEFAULT_GATEWAY
self._connection.shard_count = self.shard_count

Loading…
Cancel
Save