Browse Source

Add zstd gateway compression to speed profile

pull/9963/head
Lilly Rose Berner 6 months ago
committed by GitHub
parent
commit
91f300a28a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 23
      discord/gateway.py
  2. 19
      discord/http.py
  3. 58
      discord/utils.py
  4. 1
      pyproject.toml

23
discord/gateway.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -32,7 +33,6 @@ import sys
import time import time
import threading import threading
import traceback import traceback
import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
@ -325,8 +325,7 @@ class DiscordWebSocket:
# ws related stuff # ws related stuff
self.session_id: Optional[str] = None self.session_id: Optional[str] = None
self.sequence: Optional[int] = None self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj() self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
self._buffer: bytearray = bytearray()
self._close_code: Optional[int] = None self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@ -355,7 +354,7 @@ class DiscordWebSocket:
sequence: Optional[int] = None, sequence: Optional[int] = None,
resume: bool = False, resume: bool = False,
encoding: str = 'json', encoding: str = 'json',
zlib: bool = True, compress: bool = True,
) -> Self: ) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
@ -366,10 +365,12 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY gateway = gateway or cls.DEFAULT_GATEWAY
if zlib: if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
)
socket = await client.http.ws_connect(str(url)) socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop) ws = cls(socket, loop=client.loop)
@ -488,13 +489,11 @@ class DiscordWebSocket:
async def received_message(self, msg: Any, /) -> None: async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) msg = self._decompressor.decompress(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': # Received a partial gateway message
if msg is None:
return return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
self.log_receive(msg) self.log_receive(msg)
msg = utils._from_json(msg) msg = utils._from_json(msg)

19
discord/http.py

@ -2701,28 +2701,13 @@ class HTTPClient:
# Misc # Misc
async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: async def get_bot_gateway(self) -> Tuple[int, str]:
try:
data = await self.request(Route('GET', '/gateway'))
except HTTPException as exc:
raise GatewayNotFound() from exc
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return value.format(data['url'], encoding, INTERNAL_API_VERSION)
async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]:
try: try:
data = await self.request(Route('GET', '/gateway/bot')) data = await self.request(Route('GET', '/gateway/bot'))
except HTTPException as exc: except HTTPException as exc:
raise GatewayNotFound() from exc raise GatewayNotFound() from exc
if zlib: return data['shards'], data['url']
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION)
def get_user(self, user_id: Snowflake) -> Response[user.User]: def get_user(self, user_id: Snowflake) -> Response[user.User]:
return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) return self.request(Route('GET', '/users/{user_id}', user_id=user_id))

58
discord/utils.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import array import array
@ -41,7 +42,6 @@ from typing import (
Iterator, Iterator,
List, List,
Literal, Literal,
Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
@ -71,6 +71,7 @@ import types
import typing import typing
import warnings import warnings
import logging import logging
import zlib
import yarl import yarl
@ -81,6 +82,12 @@ except ModuleNotFoundError:
else: else:
HAS_ORJSON = True HAS_ORJSON = True
try:
import zstandard # type: ignore
except ImportError:
_HAS_ZSTD = False
else:
_HAS_ZSTD = True
__all__ = ( __all__ = (
'oauth_url', 'oauth_url',
@ -148,8 +155,11 @@ if TYPE_CHECKING:
from .invite import Invite from .invite import Invite
from .template import Template from .template import Template
class _RequestLike(Protocol): class _DecompressionContext(Protocol):
headers: Mapping[str, Any] COMPRESSION_TYPE: str
def decompress(self, data: bytes, /) -> str | None:
...
P = ParamSpec('P') P = ParamSpec('P')
@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
return f'{seq[0]} {final} {seq[1]}' return f'{seq[0]} {final} {seq[1]}'
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
if _HAS_ZSTD:
class _ZstdDecompressionContext:
__slots__ = ('context',)
COMPRESSION_TYPE: str = 'zstd-stream'
def __init__(self) -> None:
decompressor = zstandard.ZstdDecompressor()
self.context = decompressor.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
# Each WS message is a complete gateway message
return self.context.decompress(data).decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext
else:
class _ZlibDecompressionContext:
__slots__ = ('context', 'buffer')
COMPRESSION_TYPE: str = 'zlib-stream'
def __init__(self) -> None:
self.buffer: bytearray = bytearray()
self.context = zlib.decompressobj()
def decompress(self, data: bytes, /) -> str | None:
self.buffer.extend(data)
# Check whether ending is Z_SYNC_FLUSH
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
return
msg = self.context.decompress(self.buffer)
self.buffer = bytearray()
return msg.decode('utf-8')
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext

1
pyproject.toml

@ -56,6 +56,7 @@ speed = [
"aiodns>=1.1; sys_platform != 'win32'", "aiodns>=1.1; sys_platform != 'win32'",
"Brotli", "Brotli",
"cchardet==2.1.7; python_version < '3.10'", "cchardet==2.1.7; python_version < '3.10'",
"zstandard>=0.23.0"
] ]
test = [ test = [
"coverage[toml]", "coverage[toml]",

Loading…
Cancel
Save