Browse Source

Add zstd gateway compression to speed profile

pull/9963/head
Lilly Rose Berner 6 months ago
committed by GitHub
parent
commit
91f300a28a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 23
      discord/gateway.py
  2. 19
      discord/http.py
  3. 58
      discord/utils.py
  4. 1
      pyproject.toml

23
discord/gateway.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
@ -32,7 +33,6 @@ import sys
import time
import threading
import traceback
import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
@ -325,8 +325,7 @@ class DiscordWebSocket:
# ws related stuff
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@ -355,7 +354,7 @@ class DiscordWebSocket:
sequence: Optional[int] = None,
resume: bool = False,
encoding: str = 'json',
zlib: bool = True,
compress: bool = True,
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
@ -366,10 +365,12 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY
if zlib:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
)
socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
@ -488,13 +489,11 @@ class DiscordWebSocket:
async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
msg = self._decompressor.decompress(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
# Received a partial gateway message
if msg is None:
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
self.log_receive(msg)
msg = utils._from_json(msg)

19
discord/http.py

@ -2701,28 +2701,13 @@ class HTTPClient:
# Misc
async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str:
try:
data = await self.request(Route('GET', '/gateway'))
except HTTPException as exc:
raise GatewayNotFound() from exc
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return value.format(data['url'], encoding, INTERNAL_API_VERSION)
async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]:
async def get_bot_gateway(self) -> Tuple[int, str]:
try:
data = await self.request(Route('GET', '/gateway/bot'))
except HTTPException as exc:
raise GatewayNotFound() from exc
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION)
return data['shards'], data['url']
def get_user(self, user_id: Snowflake) -> Response[user.User]:
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))

58
discord/utils.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import array
@ -41,7 +42,6 @@ from typing import (
Iterator,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Protocol,
@ -71,6 +71,7 @@ import types
import typing
import warnings
import logging
import zlib
import yarl
@ -81,6 +82,12 @@ except ModuleNotFoundError:
else:
HAS_ORJSON = True
try:
import zstandard # type: ignore
except ImportError:
_HAS_ZSTD = False
else:
_HAS_ZSTD = True
__all__ = (
'oauth_url',
@ -148,8 +155,11 @@ if TYPE_CHECKING:
from .invite import Invite
from .template import Template
class _RequestLike(Protocol):
headers: Mapping[str, Any]
class _DecompressionContext(Protocol):
COMPRESSION_TYPE: str
def decompress(self, data: bytes, /) -> str | None:
...
P = ParamSpec('P')
@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
return f'{seq[0]} {final} {seq[1]}'
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
if _HAS_ZSTD:
class _ZstdDecompressionContext:
__slots__ = ('context',)
COMPRESSION_TYPE: str = 'zstd-stream'
def __init__(self) -> None:
decompressor = zstandard.ZstdDecompressor()
self.context = decompressor.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
# Each WS message is a complete gateway message
return self.context.decompress(data).decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
else:
class _ZlibDecompressionContext:
__slots__ = ('context', 'buffer')
COMPRESSION_TYPE: str = 'zlib-stream'
def __init__(self) -> None:
self.buffer: bytearray = bytearray()
self.context = zlib.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
self.buffer.extend(data)
# Check whether ending is Z_SYNC_FLUSH
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
return
msg = self.context.decompress(self.buffer)
self.buffer = bytearray()
return msg.decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext

1
pyproject.toml

@ -56,6 +56,7 @@ speed = [
"aiodns>=1.1; sys_platform != 'win32'",
"Brotli",
"cchardet==2.1.7; python_version < '3.10'",
"zstandard>=0.23.0"
]
test = [
"coverage[toml]",

Loading…
Cancel
Save