diff --git a/discord/gateway.py b/discord/gateway.py index e6fb7d8bf..13a213ce3 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import asyncio @@ -32,7 +33,6 @@ import sys import time import threading import traceback -import zlib from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple @@ -325,8 +325,7 @@ class DiscordWebSocket: # ws related stuff self.session_id: Optional[str] = None self.sequence: Optional[int] = None - self._zlib: zlib._Decompress = zlib.decompressobj() - self._buffer: bytearray = bytearray() + self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext() self._close_code: Optional[int] = None self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() @@ -355,7 +354,7 @@ class DiscordWebSocket: sequence: Optional[int] = None, resume: bool = False, encoding: str = 'json', - zlib: bool = True, + compress: bool = True, ) -> Self: """Creates a main websocket for Discord from a :class:`Client`. @@ -366,10 +365,12 @@ class DiscordWebSocket: gateway = gateway or cls.DEFAULT_GATEWAY - if zlib: - url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') - else: + if not compress: url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) + else: + url = gateway.with_query( + v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE + ) socket = await client.http.ws_connect(str(url)) ws = cls(socket, loop=client.loop) @@ -488,13 +489,11 @@ class DiscordWebSocket: async def received_message(self, msg: Any, /) -> None: if type(msg) is bytes: - self._buffer.extend(msg) + msg = self._decompressor.decompress(msg) - if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + # Received a partial gateway message + if msg is None: return - msg = self._zlib.decompress(self._buffer) - msg = msg.decode('utf-8') - self._buffer = bytearray() self.log_receive(msg) msg = utils._from_json(msg) diff --git a/discord/http.py b/discord/http.py index 24605c4fc..3c1eacb61 100644 --- a/discord/http.py +++ b/discord/http.py @@ -2701,28 +2701,13 @@ class HTTPClient: # Misc - async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: - try: - data = await self.request(Route('GET', '/gateway')) - except HTTPException as exc: - raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' - else: - value = '{0}?encoding={1}&v={2}' - return value.format(data['url'], encoding, INTERNAL_API_VERSION) - - async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]: + async def get_bot_gateway(self) -> Tuple[int, str]: try: data = await self.request(Route('GET', '/gateway/bot')) except HTTPException as exc: raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' - else: - value = '{0}?encoding={1}&v={2}' - return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION) + return data['shards'], data['url'] def get_user(self, user_id: Snowflake) -> Response[user.User]: return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) diff --git a/discord/utils.py b/discord/utils.py index ee4097c46..cb7d662b6 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import array @@ -41,7 +42,6 @@ from typing import ( Iterator, List, Literal, - Mapping, NamedTuple, Optional, Protocol, @@ -71,6 +71,7 @@ import types import typing import warnings import logging +import zlib import yarl @@ -81,6 +82,12 @@ except ModuleNotFoundError: else: HAS_ORJSON = True +try: + import zstandard # type: ignore +except ImportError: + _HAS_ZSTD = False +else: + _HAS_ZSTD = True __all__ = ( 'oauth_url', @@ -148,8 +155,11 @@ if TYPE_CHECKING: from .invite import Invite from .template import Template - class _RequestLike(Protocol): - headers: Mapping[str, Any] + class _DecompressionContext(Protocol): + COMPRESSION_TYPE: str + + def decompress(self, data: bytes, /) -> str | None: + ... P = ParamSpec('P') @@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o return f'{seq[0]} {final} {seq[1]}' return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' + + +if _HAS_ZSTD: + + class _ZstdDecompressionContext: + __slots__ = ('context',) + + COMPRESSION_TYPE: str = 'zstd-stream' + + def __init__(self) -> None: + decompressor = zstandard.ZstdDecompressor() + self.context = decompressor.decompressobj() + + def decompress(self, data: bytes, /) -> str | None: + # Each WS message is a complete gateway message + return self.context.decompress(data).decode('utf-8') + + _ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext +else: + + class _ZlibDecompressionContext: + __slots__ = ('context', 'buffer') + + COMPRESSION_TYPE: str = 'zlib-stream' + + def __init__(self) -> None: + self.buffer: bytearray = bytearray() + self.context = zlib.decompressobj() + + def decompress(self, data: bytes, /) -> str | None: + self.buffer.extend(data) + + # Check whether ending is Z_SYNC_FLUSH + if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': + return + + msg = self.context.decompress(self.buffer) + self.buffer = bytearray() + + return msg.decode('utf-8') + + _ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext diff --git a/pyproject.toml b/pyproject.toml index 596e6ef08..4ec7bc007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ speed = [ "aiodns>=1.1; sys_platform != 'win32'", "Brotli", "cchardet==2.1.7; python_version < '3.10'", + "zstandard>=0.23.0" ] test = [ "coverage[toml]",