|
|
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|
|
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
@ -30,7 +31,6 @@ import struct |
|
|
|
import time |
|
|
|
import threading |
|
|
|
import traceback |
|
|
|
import zlib |
|
|
|
|
|
|
|
from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, Sequence, TypeVar, Tuple |
|
|
|
|
|
|
@ -257,7 +257,7 @@ class DiscordWebSocket: |
|
|
|
_max_heartbeat_timeout: float |
|
|
|
_user_agent: str |
|
|
|
_super_properties: Dict[str, Any] |
|
|
|
_zlib_enabled: bool |
|
|
|
_transport_compression: bool |
|
|
|
|
|
|
|
# fmt: off |
|
|
|
DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/') |
|
|
@ -296,8 +296,7 @@ class DiscordWebSocket: |
|
|
|
# WS related stuff |
|
|
|
self.session_id: Optional[str] = None |
|
|
|
self.sequence: Optional[int] = None |
|
|
|
self._zlib: zlib._Decompress = zlib.decompressobj() |
|
|
|
self._buffer: bytearray = bytearray() |
|
|
|
self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext() |
|
|
|
self._close_code: Optional[int] = None |
|
|
|
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() |
|
|
|
|
|
|
@ -333,7 +332,7 @@ class DiscordWebSocket: |
|
|
|
sequence: Optional[int] = None, |
|
|
|
resume: bool = False, |
|
|
|
encoding: str = 'json', |
|
|
|
zlib: bool = True, |
|
|
|
compress: bool = True, |
|
|
|
) -> Self: |
|
|
|
"""Creates a main websocket for Discord from a :class:`Client`. |
|
|
|
|
|
|
@ -344,10 +343,12 @@ class DiscordWebSocket: |
|
|
|
|
|
|
|
gateway = gateway or cls.DEFAULT_GATEWAY |
|
|
|
|
|
|
|
if zlib: |
|
|
|
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') |
|
|
|
else: |
|
|
|
if not compress: |
|
|
|
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) |
|
|
|
else: |
|
|
|
url = gateway.with_query( |
|
|
|
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE |
|
|
|
) |
|
|
|
|
|
|
|
socket = await client.http.ws_connect(str(url)) |
|
|
|
ws = cls(socket, loop=client.loop) |
|
|
@ -365,7 +366,7 @@ class DiscordWebSocket: |
|
|
|
ws._max_heartbeat_timeout = client._connection.heartbeat_timeout |
|
|
|
ws._user_agent = client.http.user_agent |
|
|
|
ws._super_properties = client.http.super_properties |
|
|
|
ws._zlib_enabled = zlib |
|
|
|
ws._transport_compression = compress |
|
|
|
ws.afk = client._connection._afk |
|
|
|
ws.idle_since = client._connection._idle_since |
|
|
|
|
|
|
@ -449,7 +450,7 @@ class DiscordWebSocket: |
|
|
|
'capabilities': self.capabilities.value, |
|
|
|
'properties': self._super_properties, |
|
|
|
'presence': presence, |
|
|
|
'compress': not self._zlib_enabled, # We require at least one form of compression |
|
|
|
'compress': not self._transport_compression, # We require at least one form of compression |
|
|
|
'client_state': { |
|
|
|
'api_code_version': 0, |
|
|
|
'guild_versions': {}, |
|
|
@ -483,13 +484,11 @@ class DiscordWebSocket: |
|
|
|
|
|
|
|
async def received_message(self, msg: Any, /) -> None: |
|
|
|
if type(msg) is bytes: |
|
|
|
self._buffer.extend(msg) |
|
|
|
msg = self._decompressor.decompress(msg) |
|
|
|
|
|
|
|
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': |
|
|
|
# Received a partial gateway message |
|
|
|
if msg is None: |
|
|
|
return |
|
|
|
msg = self._zlib.decompress(self._buffer) |
|
|
|
msg = msg.decode('utf-8') |
|
|
|
self._buffer = bytearray() |
|
|
|
|
|
|
|
self.log_receive(msg) |
|
|
|
msg = utils._from_json(msg) |
|
|
|