|
|
@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|
|
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import array |
|
|
@ -41,7 +42,6 @@ from typing import ( |
|
|
|
Iterator, |
|
|
|
List, |
|
|
|
Literal, |
|
|
|
Mapping, |
|
|
|
NamedTuple, |
|
|
|
Optional, |
|
|
|
Protocol, |
|
|
@ -71,6 +71,7 @@ import types |
|
|
|
import typing |
|
|
|
import warnings |
|
|
|
import logging |
|
|
|
import zlib |
|
|
|
|
|
|
|
import yarl |
|
|
|
|
|
|
@ -81,6 +82,12 @@ except ModuleNotFoundError: |
|
|
|
else: |
|
|
|
HAS_ORJSON = True |
|
|
|
|
|
|
|
try: |
|
|
|
import zstandard # type: ignore |
|
|
|
except ImportError: |
|
|
|
_HAS_ZSTD = False |
|
|
|
else: |
|
|
|
_HAS_ZSTD = True |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'oauth_url', |
|
|
@ -148,8 +155,11 @@ if TYPE_CHECKING: |
|
|
|
from .invite import Invite |
|
|
|
from .template import Template |
|
|
|
|
|
|
|
class _RequestLike(Protocol): |
|
|
|
headers: Mapping[str, Any] |
|
|
|
class _DecompressionContext(Protocol): |
|
|
|
COMPRESSION_TYPE: str |
|
|
|
|
|
|
|
def decompress(self, data: bytes, /) -> str | None: |
|
|
|
... |
|
|
|
|
|
|
|
P = ParamSpec('P') |
|
|
|
|
|
|
@ -1416,3 +1426,45 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o |
|
|
|
return f'{seq[0]} {final} {seq[1]}' |
|
|
|
|
|
|
|
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' |
|
|
|
|
|
|
|
|
|
|
|
if _HAS_ZSTD: |
|
|
|
|
|
|
|
class _ZstdDecompressionContext: |
|
|
|
__slots__ = ('context',) |
|
|
|
|
|
|
|
COMPRESSION_TYPE: str = 'zstd-stream' |
|
|
|
|
|
|
|
def __init__(self) -> None: |
|
|
|
decompressor = zstandard.ZstdDecompressor() |
|
|
|
self.context = decompressor.decompressobj() |
|
|
|
|
|
|
|
def decompress(self, data: bytes, /) -> str | None: |
|
|
|
# Each WS message is a complete gateway message |
|
|
|
return self.context.decompress(data).decode('utf-8') |
|
|
|
|
|
|
|
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext |
|
|
|
else: |
|
|
|
|
|
|
|
class _ZlibDecompressionContext: |
|
|
|
__slots__ = ('context', 'buffer') |
|
|
|
|
|
|
|
COMPRESSION_TYPE: str = 'zlib-stream' |
|
|
|
|
|
|
|
def __init__(self) -> None: |
|
|
|
self.buffer: bytearray = bytearray() |
|
|
|
self.context = zlib.decompressobj() |
|
|
|
|
|
|
|
def decompress(self, data: bytes, /) -> str | None: |
|
|
|
self.buffer.extend(data) |
|
|
|
|
|
|
|
# Check whether ending is Z_SYNC_FLUSH |
|
|
|
if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': |
|
|
|
return |
|
|
|
|
|
|
|
msg = self.context.decompress(self.buffer) |
|
|
|
self.buffer = bytearray() |
|
|
|
|
|
|
|
return msg.decode('utf-8') |
|
|
|
|
|
|
|
_ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext |
|
|
|