Browse Source

Add zstd gateway compression to speed profile

pull/10109/head
Lilly Rose Berner 7 months ago
committed by dolfies
parent
commit
e4e708535c
  1. 29
      discord/gateway.py
  2. 62
      discord/utils.py
  3. 1
      setup.py

29
discord/gateway.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
@ -30,7 +31,6 @@ import struct
import time
import threading
import traceback
import zlib
from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, Sequence, TypeVar, Tuple
@ -257,7 +257,7 @@ class DiscordWebSocket:
_max_heartbeat_timeout: float
_user_agent: str
_super_properties: Dict[str, Any]
_zlib_enabled: bool
_transport_compression: bool
# fmt: off
DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/')
@ -296,8 +296,7 @@ class DiscordWebSocket:
# WS related stuff
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@ -333,7 +332,7 @@ class DiscordWebSocket:
sequence: Optional[int] = None,
resume: bool = False,
encoding: str = 'json',
zlib: bool = True,
compress: bool = True,
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
@ -344,10 +343,12 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY
if zlib:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
)
socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
@ -365,7 +366,7 @@ class DiscordWebSocket:
ws._max_heartbeat_timeout = client._connection.heartbeat_timeout
ws._user_agent = client.http.user_agent
ws._super_properties = client.http.super_properties
ws._zlib_enabled = zlib
ws._transport_compression = compress
ws.afk = client._connection._afk
ws.idle_since = client._connection._idle_since
@ -449,7 +450,7 @@ class DiscordWebSocket:
'capabilities': self.capabilities.value,
'properties': self._super_properties,
'presence': presence,
'compress': not self._zlib_enabled, # We require at least one form of compression
'compress': not self._transport_compression, # We require at least one form of compression
'client_state': {
'api_code_version': 0,
'guild_versions': {},
@ -483,13 +484,11 @@ class DiscordWebSocket:
async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
msg = self._decompressor.decompress(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
# Received a partial gateway message
if msg is None:
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
self.log_receive(msg)
msg = utils._from_json(msg)

62
discord/utils.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import array
@ -40,7 +41,6 @@ from typing import (
Iterator,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Protocol,
@ -75,6 +75,8 @@ import types
import typing
import warnings
import aiohttp
import logging
import zlib
import yarl
@ -85,8 +87,15 @@ except ModuleNotFoundError:
else:
HAS_ORJSON = True
from .enums import Locale, try_enum
try:
import zstandard # type: ignore
except ImportError:
HAS_ZSTD = False
else:
HAS_ZSTD = True
from .enums import Locale, try_enum
__all__ = (
'oauth_url',
@ -161,8 +170,11 @@ if TYPE_CHECKING:
from .commands import ApplicationCommand
from .entitlements import Gift
class _RequestLike(Protocol):
headers: Mapping[str, Any]
class _DecompressionContext(Protocol):
COMPRESSION_TYPE: str
def decompress(self, data: bytes, /) -> str | None:
...
P = ParamSpec('P')
@ -1735,3 +1747,45 @@ else:
return unsigned_val
else:
return -((unsigned_val ^ 0xFFFFFFFF) + 1)
if HAS_ZSTD:
class _ZstdDecompressionContext:
__slots__ = ('context',)
COMPRESSION_TYPE: str = 'zstd-stream'
def __init__(self) -> None:
decompressor = zstandard.ZstdDecompressor()
self.context = decompressor.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
# Each WS message is a complete gateway message
return self.context.decompress(data).decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
else:
class _ZlibDecompressionContext:
__slots__ = ('context', 'buffer')
COMPRESSION_TYPE: str = 'zlib-stream'
def __init__(self) -> None:
self.buffer: bytearray = bytearray()
self.context = zlib.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
self.buffer.extend(data)
# Check whether ending is Z_SYNC_FLUSH
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
return
msg = self.context.decompress(self.buffer)
self.buffer = bytearray()
return msg.decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext

1
setup.py

@ -56,6 +56,7 @@ extras_require = {
'Brotli',
'cchardet==2.1.7; python_version < "3.10"',
'mmh3>=2.5',
'zstandard>=0.23.0',
],
'test': [
'coverage[toml]',

Loading…
Cancel
Save