Browse Source

Add zstd gateway compression to speed profile

pull/10109/head
Lilly Rose Berner 7 months ago
committed by dolfies
parent
commit
e4e708535c
  1. 29
      discord/gateway.py
  2. 62
      discord/utils.py
  3. 1
      setup.py

29
discord/gateway.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -30,7 +31,6 @@ import struct
import time import time
import threading import threading
import traceback import traceback
import zlib
from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, Sequence, TypeVar, Tuple from typing import Any, Callable, Coroutine, Dict, List, TYPE_CHECKING, NamedTuple, Optional, Sequence, TypeVar, Tuple
@ -257,7 +257,7 @@ class DiscordWebSocket:
_max_heartbeat_timeout: float _max_heartbeat_timeout: float
_user_agent: str _user_agent: str
_super_properties: Dict[str, Any] _super_properties: Dict[str, Any]
_zlib_enabled: bool _transport_compression: bool
# fmt: off # fmt: off
DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/') DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/')
@ -296,8 +296,7 @@ class DiscordWebSocket:
# WS related stuff # WS related stuff
self.session_id: Optional[str] = None self.session_id: Optional[str] = None
self.sequence: Optional[int] = None self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj() self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
self._buffer: bytearray = bytearray()
self._close_code: Optional[int] = None self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@ -333,7 +332,7 @@ class DiscordWebSocket:
sequence: Optional[int] = None, sequence: Optional[int] = None,
resume: bool = False, resume: bool = False,
encoding: str = 'json', encoding: str = 'json',
zlib: bool = True, compress: bool = True,
) -> Self: ) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
@ -344,10 +343,12 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY gateway = gateway or cls.DEFAULT_GATEWAY
if zlib: if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
)
socket = await client.http.ws_connect(str(url)) socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop) ws = cls(socket, loop=client.loop)
@ -365,7 +366,7 @@ class DiscordWebSocket:
ws._max_heartbeat_timeout = client._connection.heartbeat_timeout ws._max_heartbeat_timeout = client._connection.heartbeat_timeout
ws._user_agent = client.http.user_agent ws._user_agent = client.http.user_agent
ws._super_properties = client.http.super_properties ws._super_properties = client.http.super_properties
ws._zlib_enabled = zlib ws._transport_compression = compress
ws.afk = client._connection._afk ws.afk = client._connection._afk
ws.idle_since = client._connection._idle_since ws.idle_since = client._connection._idle_since
@ -449,7 +450,7 @@ class DiscordWebSocket:
'capabilities': self.capabilities.value, 'capabilities': self.capabilities.value,
'properties': self._super_properties, 'properties': self._super_properties,
'presence': presence, 'presence': presence,
'compress': not self._zlib_enabled, # We require at least one form of compression 'compress': not self._transport_compression, # We require at least one form of compression
'client_state': { 'client_state': {
'api_code_version': 0, 'api_code_version': 0,
'guild_versions': {}, 'guild_versions': {},
@ -483,13 +484,11 @@ class DiscordWebSocket:
async def received_message(self, msg: Any, /) -> None: async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) msg = self._decompressor.decompress(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': # Received a partial gateway message
if msg is None:
return return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
self.log_receive(msg) self.log_receive(msg)
msg = utils._from_json(msg) msg = utils._from_json(msg)

62
discord/utils.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import array import array
@ -40,7 +41,6 @@ from typing import (
Iterator, Iterator,
List, List,
Literal, Literal,
Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
@ -75,6 +75,8 @@ import types
import typing import typing
import warnings import warnings
import aiohttp import aiohttp
import logging
import zlib
import yarl import yarl
@ -85,8 +87,15 @@ except ModuleNotFoundError:
else: else:
HAS_ORJSON = True HAS_ORJSON = True
from .enums import Locale, try_enum
try:
import zstandard # type: ignore
except ImportError:
HAS_ZSTD = False
else:
HAS_ZSTD = True
from .enums import Locale, try_enum
__all__ = ( __all__ = (
'oauth_url', 'oauth_url',
@ -161,8 +170,11 @@ if TYPE_CHECKING:
from .commands import ApplicationCommand from .commands import ApplicationCommand
from .entitlements import Gift from .entitlements import Gift
class _RequestLike(Protocol): class _DecompressionContext(Protocol):
headers: Mapping[str, Any] COMPRESSION_TYPE: str
def decompress(self, data: bytes, /) -> str | None:
...
P = ParamSpec('P') P = ParamSpec('P')
@ -1735,3 +1747,45 @@ else:
return unsigned_val return unsigned_val
else: else:
return -((unsigned_val ^ 0xFFFFFFFF) + 1) return -((unsigned_val ^ 0xFFFFFFFF) + 1)
if HAS_ZSTD:
class _ZstdDecompressionContext:
__slots__ = ('context',)
COMPRESSION_TYPE: str = 'zstd-stream'
def __init__(self) -> None:
decompressor = zstandard.ZstdDecompressor()
self.context = decompressor.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
# Each WS message is a complete gateway message
return self.context.decompress(data).decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
else:
class _ZlibDecompressionContext:
__slots__ = ('context', 'buffer')
COMPRESSION_TYPE: str = 'zlib-stream'
def __init__(self) -> None:
self.buffer: bytearray = bytearray()
self.context = zlib.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
self.buffer.extend(data)
# Check whether ending is Z_SYNC_FLUSH
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
return
msg = self.context.decompress(self.buffer)
self.buffer = bytearray()
return msg.decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext

1
setup.py

@ -56,6 +56,7 @@ extras_require = {
'Brotli', 'Brotli',
'cchardet==2.1.7; python_version < "3.10"', 'cchardet==2.1.7; python_version < "3.10"',
'mmh3>=2.5', 'mmh3>=2.5',
'zstandard>=0.23.0',
], ],
'test': [ 'test': [
'coverage[toml]', 'coverage[toml]',

Loading…
Cancel
Save