From 311eac97b02b030f999f6b57fb4545fc0bb921f5 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 21 Aug 2021 14:48:22 -0400 Subject: [PATCH] Reformat state.py --- discord/state.py | 137 +++++++++++++++++++++++++++++------------------ 1 file changed, 86 insertions(+), 51 deletions(-) diff --git a/discord/state.py b/discord/state.py index ce5fdac27..222f6c709 100644 --- a/discord/state.py +++ b/discord/state.py @@ -80,8 +80,16 @@ if TYPE_CHECKING: CS = TypeVar('CS', bound='ConnectionState') Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] + class ChunkRequest: - def __init__(self, guild_id: int, loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True) -> None: + def __init__( + self, + guild_id: int, + loop: asyncio.AbstractEventLoop, + resolver: Callable[[int], Any], + *, + cache: bool = True, + ) -> None: self.guild_id: int = guild_id self.resolver: Callable[[int], Any] = resolver self.loop: asyncio.AbstractEventLoop = loop @@ -120,21 +128,33 @@ class ChunkRequest: if not future.done(): future.set_result(self.buffer) + log: logging.Logger = logging.getLogger(__name__) + async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: try: await coroutine except Exception: log.exception('Exception occurred during %s', info) + class ConnectionState: if TYPE_CHECKING: _get_websocket: Callable[..., DiscordWebSocket] _get_client: Callable[..., Client] _parsers: Dict[str, Callable[[Dict[str, Any]], None]] - def __init__(self, *, dispatch: Callable, handlers: Dict[str, Callable], hooks: Dict[str, Callable], http: HTTPClient, loop: asyncio.AbstractEventLoop, **options: Any) -> None: + def __init__( + self, + *, + dispatch: Callable, + handlers: Dict[str, Callable], + hooks: Dict[str, Callable], + http: HTTPClient, + loop: asyncio.AbstractEventLoop, + **options: Any, + ) -> None: self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http self.max_messages: Optional[int] = options.get('max_messages', 1000) @@ -292,7 +312,7 @@ class ConnectionState: def _get_voice_client(self, guild_id: Optional[int]) -> Optional[VoiceProtocol]: # the keys of self._voice_clients are ints - return self._voice_clients.get(guild_id) # type: ignore + return self._voice_clients.get(guild_id) # type: ignore def _add_voice_client(self, guild_id: int, voice: VoiceProtocol) -> None: self._voice_clients[guild_id] = voice @@ -302,7 +322,7 @@ class ConnectionState: def _update_references(self, ws: DiscordWebSocket) -> None: for vc in self.voice_clients: - vc.main_ws = ws # type: ignore + vc.main_ws = ws # type: ignore def store_user(self, data: UserPayload) -> User: user_id = int(data['id']) @@ -326,11 +346,11 @@ class ConnectionState: def get_user(self, id: Optional[int]) -> Optional[User]: # the keys of self._users are ints - return self._users.get(id) # type: ignore + return self._users.get(id) # type: ignore def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: # the id will be present here - emoji_id = int(data['id']) # type: ignore + emoji_id = int(data['id']) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji @@ -355,7 +375,7 @@ class ConnectionState: def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]: # the keys of self._guilds are ints - return self._guilds.get(guild_id) # type: ignore + return self._guilds.get(guild_id) # type: ignore def _add_guild(self, guild: Guild) -> None: self._guilds[guild.id] = guild @@ -381,11 +401,11 @@ class ConnectionState: def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]: # the keys of self._emojis are ints - return self._emojis.get(emoji_id) # type: ignore + return self._emojis.get(emoji_id) # type: ignore def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]: # the keys of self._stickers are ints - return self._stickers.get(sticker_id) # type: ignore + return self._stickers.get(sticker_id) # type: ignore @property def private_channels(self) -> List[PrivateChannel]: @@ -394,16 +414,16 @@ class ConnectionState: def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: try: # the keys of self._private_channels are ints - value = self._private_channels[channel_id] # type: ignore + value = self._private_channels[channel_id] # type: ignore except KeyError: return None else: - self._private_channels.move_to_end(channel_id) # type: ignore + self._private_channels.move_to_end(channel_id) # type: ignore return value def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[PrivateChannel]: # the keys of self._private_channels are ints - return self._private_channels_by_user.get(user_id) # type: ignore + return self._private_channels_by_user.get(user_id) # type: ignore def _add_private_channel(self, channel: PrivateChannel) -> None: channel_id = channel.id @@ -419,7 +439,7 @@ class ConnectionState: def add_dm_channel(self, data: DMChannelPayload) -> DMChannel: # self.user is *always* cached when this is called - channel = DMChannel(me=self.user, state=self, data=data) # type: ignore + channel = DMChannel(me=self.user, state=self, data=data) # type: ignore self._add_private_channel(channel) return channel @@ -454,8 +474,10 @@ class ConnectionState: return channel or PartialMessageable(state=self, id=channel_id), guild - async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None) -> None: - ws = self._get_websocket(guild_id) # This is ignored upstream + async def chunker( + self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None + ) -> None: + ws = self._get_websocket(guild_id) # This is ignored upstream await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): @@ -469,7 +491,9 @@ class ConnectionState: try: # start the query operation - await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce) + await ws.request_chunks( + guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce + ) return await asyncio.wait_for(request.wait(), timeout=30.0) except asyncio.TimeoutError: log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) @@ -510,7 +534,7 @@ class ConnectionState: try: del self._ready_state except AttributeError: - pass # already been deleted somehow + pass # already been deleted somehow except asyncio.CancelledError: pass @@ -538,7 +562,7 @@ class ConnectionState: else: self.application_id = utils._get_as_snowflake(application, 'id') # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore + self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore for guild_data in data['guilds']: self._add_guild_from_data(guild_data) @@ -552,13 +576,13 @@ class ConnectionState: def parse_message_create(self, data) -> None: channel, _ = self._get_guild_channel(data) # channel would be the correct type here - message = Message(channel=channel, data=data, state=self) # type: ignore + message = Message(channel=channel, data=data, state=self) # type: ignore self.dispatch('message', message) if self._messages is not None: self._messages.append(message) # we ensure that the channel is either a TextChannel or Thread if channel and channel.__class__ in (TextChannel, Thread): - channel.last_message_id = message.id # type: ignore + channel.last_message_id = message.id # type: ignore def parse_message_delete(self, data) -> None: raw = RawMessageDeleteEvent(data) @@ -581,7 +605,7 @@ class ConnectionState: self.dispatch('bulk_message_delete', found_messages) for msg in found_messages: # self._messages won't be None here - self._messages.remove(msg) # type: ignore + self._messages.remove(msg) # type: ignore def parse_message_update(self, data) -> None: raw = RawMessageUpdateEvent(data) @@ -650,7 +674,7 @@ class ConnectionState: emoji = self._upgrade_partial_emoji(emoji) try: reaction = message._remove_reaction(data, emoji, raw.user_id) - except (AttributeError, ValueError): # eventual consistency lol + except (AttributeError, ValueError): # eventual consistency lol pass else: user = self._get_reaction_user(message.channel, raw.user_id) @@ -668,7 +692,7 @@ class ConnectionState: if message is not None: try: reaction = message._clear_emoji(emoji) - except (AttributeError, ValueError): # eventual consistency lol + except (AttributeError, ValueError): # eventual consistency lol pass else: if reaction: @@ -707,7 +731,7 @@ class ConnectionState: def parse_user_update(self, data) -> None: # self.user is *always* cached when this is called - user: ClientUser = self.user # type: ignore + user: ClientUser = self.user # type: ignore user._update(data) ref = self._users.get(user.id) if ref: @@ -737,7 +761,7 @@ class ConnectionState: channel = self._get_private_channel(channel_id) old_channel = copy.copy(channel) # the channel is a GroupChannel - channel._update_group(data) # type: ignore + channel._update_group(data) # type: ignore self.dispatch('private_channel_update', old_channel, channel) return @@ -764,8 +788,8 @@ class ConnectionState: guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here - channel = factory(guild=guild, state=self, data=data) # type: ignore - guild._add_channel(channel) # type: ignore + channel = factory(guild=guild, state=self, data=data) # type: ignore + guild._add_channel(channel) # type: ignore self.dispatch('guild_channel_create', channel) else: log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) @@ -833,7 +857,7 @@ class ConnectionState: thread_id = int(data['id']) thread = guild.get_thread(thread_id) if thread is not None: - guild._remove_thread(thread) # type: ignore + guild._remove_thread(thread) # type: ignore self.dispatch('thread_delete', thread) def parse_thread_list_sync(self, data) -> None: @@ -853,10 +877,7 @@ class ConnectionState: else: previous_threads = guild._filter_threads(channel_ids) - threads = { - d['id']: guild._store_thread(d) - for d in data.get('threads', []) - } + threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])} for member in data.get('members', []): try: @@ -951,7 +972,7 @@ class ConnectionState: user_id = int(data['user']['id']) member = guild.get_member(user_id) if member is not None: - guild._remove_member(member) # type: ignore + guild._remove_member(member) # type: ignore self.dispatch('member_remove', member) else: log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -995,7 +1016,7 @@ class ConnectionState: for emoji in before_emojis: self._emojis.pop(emoji.id, None) # guild won't be None here - guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) #type: ignore + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) # type: ignore self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) def parse_guild_stickers_update(self, data) -> None: @@ -1008,7 +1029,7 @@ class ConnectionState: for emoji in before_stickers: self._stickers.pop(emoji.id, None) # guild won't be None here - guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore + guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) def _get_create_guild(self, data): @@ -1101,7 +1122,9 @@ class ConnectionState: # do a cleanup of the messages cache if self._messages is not None: - self._messages: Optional[Deque[Message]] = deque((msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages) + self._messages: Optional[Deque[Message]] = deque( + (msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages + ) self._remove_guild(guild) self.dispatch('guild_remove', guild) @@ -1171,7 +1194,7 @@ class ConnectionState: presences = data.get('presences', []) # the guild won't be None here - members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore + members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) if presences: @@ -1273,7 +1296,7 @@ class ConnectionState: channel_id = utils._get_as_snowflake(data, 'channel_id') flags = self.member_cache_flags # self.user is *always* cached when this is called - self_id = self.user.id # type: ignore + self_id = self.user.id # type: ignore if guild is not None: if int(data['user_id']) == self_id: voice = self._get_voice_client(guild.id) @@ -1281,13 +1304,13 @@ class ConnectionState: coro = voice.on_voice_state_update(data) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) - member, before, after = guild._update_voice_state(data, channel_id) # type: ignore + member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: if flags.voice: if channel_id is None and flags._voice_only and member.id != self_id: # Only remove from cache if we only have the voice flag enabled # Member doesn't meet the Snowflake protocol currently - guild._remove_member(member) # type: ignore + guild._remove_member(member) # type: ignore elif channel_id is not None: guild._add_member(member) @@ -1316,7 +1339,7 @@ class ConnectionState: elif isinstance(channel, (Thread, TextChannel)) and guild is not None: # user_id won't be None - member = guild.get_member(user_id) # type: ignore + member = guild.get_member(user_id) # type: ignore if member is None: member_data = data.get('member') @@ -1368,9 +1391,12 @@ class ConnectionState: if channel is not None: return channel - def create_message(self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload) -> Message: + def create_message( + self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload + ) -> Message: return Message(state=self, channel=channel, data=data) + class AutoShardedConnectionState(ConnectionState): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1379,7 +1405,7 @@ class AutoShardedConnectionState(ConnectionState): def _update_message_references(self) -> None: # self._messages won't be None when this is called - for msg in self._messages: # type: ignore + for msg in self._messages: # type: ignore if not msg.guild: continue @@ -1388,9 +1414,18 @@ class AutoShardedConnectionState(ConnectionState): channel_id = msg.channel.id channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id) # channel will either be a TextChannel, Thread or Object - msg._rebind_cached_references(new_guild, channel) # type: ignore - - async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, shard_id: Optional[int] = None, nonce: Optional[str] = None) -> None: + msg._rebind_cached_references(new_guild, channel) # type: ignore + + async def chunker( + self, + guild_id: int, + query: str = '', + limit: int = 0, + presences: bool = False, + *, + shard_id: Optional[int] = None, + nonce: Optional[str] = None, + ) -> None: ws = self._get_websocket(guild_id, shard_id=shard_id) await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) @@ -1435,9 +1470,9 @@ class AutoShardedConnectionState(ConnectionState): try: await utils.sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: - log.warning('Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, - timeout, - len(guilds)) + log.warning( + 'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) + ) for guild in children: if guild.unavailable is False: self.dispatch('guild_available', guild) @@ -1450,7 +1485,7 @@ class AutoShardedConnectionState(ConnectionState): try: del self._ready_state except AttributeError: - pass # already been deleted somehow + pass # already been deleted somehow # regular users cannot shard so we won't worry about it here. @@ -1467,7 +1502,7 @@ class AutoShardedConnectionState(ConnectionState): self.user = user = ClientUser(state=self, data=data['user']) # self._users is a list of Users, we're setting a ClientUser - self._users[user.id] = user # type: ignore + self._users[user.id] = user # type: ignore if self.application_id is None: try: