|
@ -80,8 +80,16 @@ if TYPE_CHECKING: |
|
|
CS = TypeVar('CS', bound='ConnectionState') |
|
|
CS = TypeVar('CS', bound='ConnectionState') |
|
|
Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] |
|
|
Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChunkRequest: |
|
|
class ChunkRequest: |
|
|
def __init__(self, guild_id: int, loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True) -> None: |
|
|
def __init__( |
|
|
|
|
|
self, |
|
|
|
|
|
guild_id: int, |
|
|
|
|
|
loop: asyncio.AbstractEventLoop, |
|
|
|
|
|
resolver: Callable[[int], Any], |
|
|
|
|
|
*, |
|
|
|
|
|
cache: bool = True, |
|
|
|
|
|
) -> None: |
|
|
self.guild_id: int = guild_id |
|
|
self.guild_id: int = guild_id |
|
|
self.resolver: Callable[[int], Any] = resolver |
|
|
self.resolver: Callable[[int], Any] = resolver |
|
|
self.loop: asyncio.AbstractEventLoop = loop |
|
|
self.loop: asyncio.AbstractEventLoop = loop |
|
@ -120,21 +128,33 @@ class ChunkRequest: |
|
|
if not future.done(): |
|
|
if not future.done(): |
|
|
future.set_result(self.buffer) |
|
|
future.set_result(self.buffer) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log: logging.Logger = logging.getLogger(__name__) |
|
|
log: logging.Logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: |
|
|
async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: |
|
|
try: |
|
|
try: |
|
|
await coroutine |
|
|
await coroutine |
|
|
except Exception: |
|
|
except Exception: |
|
|
log.exception('Exception occurred during %s', info) |
|
|
log.exception('Exception occurred during %s', info) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConnectionState: |
|
|
class ConnectionState: |
|
|
if TYPE_CHECKING: |
|
|
if TYPE_CHECKING: |
|
|
_get_websocket: Callable[..., DiscordWebSocket] |
|
|
_get_websocket: Callable[..., DiscordWebSocket] |
|
|
_get_client: Callable[..., Client] |
|
|
_get_client: Callable[..., Client] |
|
|
_parsers: Dict[str, Callable[[Dict[str, Any]], None]] |
|
|
_parsers: Dict[str, Callable[[Dict[str, Any]], None]] |
|
|
|
|
|
|
|
|
def __init__(self, *, dispatch: Callable, handlers: Dict[str, Callable], hooks: Dict[str, Callable], http: HTTPClient, loop: asyncio.AbstractEventLoop, **options: Any) -> None: |
|
|
def __init__( |
|
|
|
|
|
self, |
|
|
|
|
|
*, |
|
|
|
|
|
dispatch: Callable, |
|
|
|
|
|
handlers: Dict[str, Callable], |
|
|
|
|
|
hooks: Dict[str, Callable], |
|
|
|
|
|
http: HTTPClient, |
|
|
|
|
|
loop: asyncio.AbstractEventLoop, |
|
|
|
|
|
**options: Any, |
|
|
|
|
|
) -> None: |
|
|
self.loop: asyncio.AbstractEventLoop = loop |
|
|
self.loop: asyncio.AbstractEventLoop = loop |
|
|
self.http: HTTPClient = http |
|
|
self.http: HTTPClient = http |
|
|
self.max_messages: Optional[int] = options.get('max_messages', 1000) |
|
|
self.max_messages: Optional[int] = options.get('max_messages', 1000) |
|
@ -454,7 +474,9 @@ class ConnectionState: |
|
|
|
|
|
|
|
|
return channel or PartialMessageable(state=self, id=channel_id), guild |
|
|
return channel or PartialMessageable(state=self, id=channel_id), guild |
|
|
|
|
|
|
|
|
async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None) -> None: |
|
|
async def chunker( |
|
|
|
|
|
self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None |
|
|
|
|
|
) -> None: |
|
|
ws = self._get_websocket(guild_id) # This is ignored upstream |
|
|
ws = self._get_websocket(guild_id) # This is ignored upstream |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) |
|
|
|
|
|
|
|
@ -469,7 +491,9 @@ class ConnectionState: |
|
|
|
|
|
|
|
|
try: |
|
|
try: |
|
|
# start the query operation |
|
|
# start the query operation |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce) |
|
|
await ws.request_chunks( |
|
|
|
|
|
guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce |
|
|
|
|
|
) |
|
|
return await asyncio.wait_for(request.wait(), timeout=30.0) |
|
|
return await asyncio.wait_for(request.wait(), timeout=30.0) |
|
|
except asyncio.TimeoutError: |
|
|
except asyncio.TimeoutError: |
|
|
log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) |
|
|
log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) |
|
@ -853,10 +877,7 @@ class ConnectionState: |
|
|
else: |
|
|
else: |
|
|
previous_threads = guild._filter_threads(channel_ids) |
|
|
previous_threads = guild._filter_threads(channel_ids) |
|
|
|
|
|
|
|
|
threads = { |
|
|
threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])} |
|
|
d['id']: guild._store_thread(d) |
|
|
|
|
|
for d in data.get('threads', []) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for member in data.get('members', []): |
|
|
for member in data.get('members', []): |
|
|
try: |
|
|
try: |
|
@ -995,7 +1016,7 @@ class ConnectionState: |
|
|
for emoji in before_emojis: |
|
|
for emoji in before_emojis: |
|
|
self._emojis.pop(emoji.id, None) |
|
|
self._emojis.pop(emoji.id, None) |
|
|
# guild won't be None here |
|
|
# guild won't be None here |
|
|
guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) #type: ignore |
|
|
guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) # type: ignore |
|
|
self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) |
|
|
self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) |
|
|
|
|
|
|
|
|
def parse_guild_stickers_update(self, data) -> None: |
|
|
def parse_guild_stickers_update(self, data) -> None: |
|
@ -1101,7 +1122,9 @@ class ConnectionState: |
|
|
|
|
|
|
|
|
# do a cleanup of the messages cache |
|
|
# do a cleanup of the messages cache |
|
|
if self._messages is not None: |
|
|
if self._messages is not None: |
|
|
self._messages: Optional[Deque[Message]] = deque((msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages) |
|
|
self._messages: Optional[Deque[Message]] = deque( |
|
|
|
|
|
(msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
self._remove_guild(guild) |
|
|
self._remove_guild(guild) |
|
|
self.dispatch('guild_remove', guild) |
|
|
self.dispatch('guild_remove', guild) |
|
@ -1368,9 +1391,12 @@ class ConnectionState: |
|
|
if channel is not None: |
|
|
if channel is not None: |
|
|
return channel |
|
|
return channel |
|
|
|
|
|
|
|
|
def create_message(self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload) -> Message: |
|
|
def create_message( |
|
|
|
|
|
self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload |
|
|
|
|
|
) -> Message: |
|
|
return Message(state=self, channel=channel, data=data) |
|
|
return Message(state=self, channel=channel, data=data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AutoShardedConnectionState(ConnectionState): |
|
|
class AutoShardedConnectionState(ConnectionState): |
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None: |
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None: |
|
|
super().__init__(*args, **kwargs) |
|
|
super().__init__(*args, **kwargs) |
|
@ -1390,7 +1416,16 @@ class AutoShardedConnectionState(ConnectionState): |
|
|
# channel will either be a TextChannel, Thread or Object |
|
|
# channel will either be a TextChannel, Thread or Object |
|
|
msg._rebind_cached_references(new_guild, channel) # type: ignore |
|
|
msg._rebind_cached_references(new_guild, channel) # type: ignore |
|
|
|
|
|
|
|
|
async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, shard_id: Optional[int] = None, nonce: Optional[str] = None) -> None: |
|
|
async def chunker( |
|
|
|
|
|
self, |
|
|
|
|
|
guild_id: int, |
|
|
|
|
|
query: str = '', |
|
|
|
|
|
limit: int = 0, |
|
|
|
|
|
presences: bool = False, |
|
|
|
|
|
*, |
|
|
|
|
|
shard_id: Optional[int] = None, |
|
|
|
|
|
nonce: Optional[str] = None, |
|
|
|
|
|
) -> None: |
|
|
ws = self._get_websocket(guild_id, shard_id=shard_id) |
|
|
ws = self._get_websocket(guild_id, shard_id=shard_id) |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) |
|
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) |
|
|
|
|
|
|
|
@ -1435,9 +1470,9 @@ class AutoShardedConnectionState(ConnectionState): |
|
|
try: |
|
|
try: |
|
|
await utils.sane_wait_for(futures, timeout=timeout) |
|
|
await utils.sane_wait_for(futures, timeout=timeout) |
|
|
except asyncio.TimeoutError: |
|
|
except asyncio.TimeoutError: |
|
|
log.warning('Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, |
|
|
log.warning( |
|
|
timeout, |
|
|
'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) |
|
|
len(guilds)) |
|
|
) |
|
|
for guild in children: |
|
|
for guild in children: |
|
|
if guild.unavailable is False: |
|
|
if guild.unavailable is False: |
|
|
self.dispatch('guild_available', guild) |
|
|
self.dispatch('guild_available', guild) |
|
|