Browse Source

Make access_private_channel() consistent with the client

pull/10109/head
dolfies 3 years ago
parent
commit
51eb06353a
  1. 14
      discord/channel.py
  2. 15
      discord/state.py

14
discord/channel.py

@ -2227,7 +2227,7 @@ class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
The direct message channel ID.
"""
__slots__ = ('id', 'recipient', 'me', 'last_message_id', '_state')
__slots__ = ('id', 'recipient', 'me', 'last_message_id', '_state', '_accessed')
def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
self._state: ConnectionState = state
@ -2235,6 +2235,7 @@ class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
self.recipient: User = state.store_user(data['recipients'][0])
self.me: ClientUser = me
self.id: int = int(data['id'])
self._accessed: bool = False
def _get_voice_client_key(self) -> Tuple[int, str]:
return self.me.id, 'self_id'
@ -2246,7 +2247,9 @@ class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
return PrivateCall(**kwargs)
async def _get_channel(self) -> Self:
await self._state.access_private_channel(self.id)
if not self._accessed:
await self._state.access_private_channel(self.id)
self._accessed = True
return self
async def _initial_ring(self) -> None:
@ -2479,7 +2482,7 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
The group channel's name if provided.
"""
__slots__ = ('last_message_id', 'id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state')
__slots__ = ('last_message_id', 'id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state', '_accessed')
def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload):
self._state: ConnectionState = state
@ -2493,6 +2496,7 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
self.name: Optional[str] = data.get('name')
self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])]
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._accessed: bool = False
self.owner: Optional[BaseUser]
if self.owner_id == self.me.id:
@ -2507,7 +2511,9 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
return self.me.id, self.id
async def _get_channel(self) -> Self:
await self._state.access_private_channel(self.id)
if not self._accessed:
await self._state.access_private_channel(self.id)
self._accessed = True
return self
def _initial_ring(self):

15
discord/state.py

@ -46,8 +46,6 @@ from typing import (
)
import weakref
import inspect
import time
import random
from math import ceil
from .errors import NotFound
@ -472,7 +470,6 @@ class ConnectionState:
self._relationships: Dict[int, Relationship] = {}
self._private_channels: Dict[int, PrivateChannel] = {}
self._private_channels_by_user: Dict[int, DMChannel] = {}
self._last_private_channel: tuple = (None, None)
if self.max_messages is not None:
self._messages: Optional[Deque[Message]] = deque(maxlen=self.max_messages)
@ -655,11 +652,6 @@ class ConnectionState:
return list(self._private_channels.values())
async def access_private_channel(self, channel_id: int) -> None:
if not self._get_accessed_private_channel(channel_id):
await self._access_private_channel(channel_id)
self._set_accessed_private_channel(channel_id)
async def _access_private_channel(self, channel_id: int) -> None:
if (ws := self.ws) is None:
return
@ -668,13 +660,6 @@ class ConnectionState:
except Exception as exc:
_log.warning('Sending ACCESS_DM failed for channel %s, (%s).', channel_id, exc)
def _set_accessed_private_channel(self, channel_id):
self._last_private_channel = (channel_id, time.time())
def _get_accessed_private_channel(self, channel_id):
timestamp, existing_id = self._last_private_channel
return existing_id == channel_id and int(time.time() - timestamp) < random.randrange(120000, 420000)
def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]:
# The keys of self._private_channels are ints
return self._private_channels.get(channel_id) # type: ignore

Loading…
Cancel
Save