diff --git a/discord/utils.py b/discord/utils.py index a4aa2835a..986d166d3 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -74,11 +74,6 @@ import logging import yarl -if sys.version_info >= (3, 14): - import compression.zstd -else: - import zlib - try: import orjson # type: ignore except ModuleNotFoundError: @@ -87,11 +82,18 @@ else: HAS_ORJSON = True try: - import zstandard # type: ignore -except ImportError: - _HAS_ZSTD = False -else: + from zstandard import ZstdDecompressor # type: ignore + _HAS_ZSTD = True +except ImportError: + try: + from compression.zstd import ZstdDecompressor # type: ignore + except ImportError: + import zlib + + _HAS_ZSTD = False + else: + _HAS_ZSTD = True __all__ = ( 'oauth_url', @@ -1429,32 +1431,16 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o if _HAS_ZSTD: class _ZstdDecompressionContext: - __slots__ = ('context',) - - COMPRESSION_TYPE: str = 'zstd-stream' - - def __init__(self) -> None: - decompressor = zstandard.ZstdDecompressor() - self.context = decompressor.decompressobj() - - def decompress(self, data: bytes, /) -> str | None: - # Each WS message is a complete gateway message - return self.context.decompress(data).decode('utf-8') - - _ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext -elif sys.version_info >= (3, 14): - - class _ZstdDecompressionContext: - __slots__ = ('context',) + __slots__ = ('decompressor',) COMPRESSION_TYPE: str = 'zstd-stream' def __init__(self) -> None: - self.context = compression.zstd.ZstdDecompressor() + self.decompressor = ZstdDecompressor() def decompress(self, data: bytes, /) -> str | None: # Each WS message is a complete gateway message - return self.context.decompress(data).decode('utf-8') + return self.decompressor.decompress(data).decode('utf-8') _ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext else: