diff --git a/discord/audit_logs.py b/discord/audit_logs.py index d684d6926..479b9c094 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -33,7 +33,15 @@ from .invite import Invite from .mixins import Hashable from .object import Object from .permissions import PermissionOverwrite, Permissions -from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets +from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets, AutoModRule +from .role import Role +from .emoji import Emoji +from .member import Member +from .scheduled_event import ScheduledEvent +from .stage_instance import StageInstance +from .sticker import GuildSticker +from .threads import Thread +from .channel import StageChannel __all__ = ( 'AuditLogDiff', @@ -46,11 +54,7 @@ if TYPE_CHECKING: import datetime from . import abc - from .emoji import Emoji from .guild import Guild - from .member import Member - from .role import Role - from .scheduled_event import ScheduledEvent from .state import ConnectionState from .types.audit_log import ( AuditLogChange as AuditLogChangePayload, @@ -123,7 +127,7 @@ def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Opti def _transform_roles(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[Role, Object]]: - return [entry.guild.get_role(int(role_id)) or Object(role_id) for role_id in data] + return [entry.guild.get_role(int(role_id)) or Object(role_id, type=Role) for role_id in data] def _transform_overwrites( @@ -144,7 +148,7 @@ def _transform_overwrites( target = entry._get_member(ow_id) if target is None: - target = Object(id=ow_id) + target = Object(id=ow_id, type=Role if ow_type == '0' else Member) overwrites.append((target, ow)) @@ -366,7 +370,7 @@ class AuditLogChanges: role = g.get_role(role_id) if role is None: - role = Object(id=role_id) + role = Object(id=role_id, type=Role) role.name = e['name'] # type: ignore # Object doesn't usually have name data.append(role) @@ -537,13 +541,13 @@ class AuditLogEntry(Hashable): elif the_type == '0': role = self.guild.get_role(instance_id) if role is None: - role = Object(id=instance_id) + role = Object(id=instance_id, type=Role) role.name = extra.get('role_name') # type: ignore # Object doesn't usually have name self.extra = role elif self.action.name.startswith('stage_instance'): channel_id = int(extra['channel_id']) self.extra = _AuditLogProxyStageInstanceAction( - channel=self.guild.get_channel(channel_id) or Object(id=channel_id) + channel=self.guild.get_channel(channel_id) or Object(id=channel_id, type=StageChannel) ) # this key is not present when the above is present, typically. @@ -617,7 +621,7 @@ class AuditLogEntry(Hashable): return self._get_member(target_id) def _convert_target_role(self, target_id: int) -> Union[Role, Object]: - return self.guild.get_role(target_id) or Object(id=target_id) + return self.guild.get_role(target_id) or Object(id=target_id, type=Role) def _convert_target_invite(self, target_id: None) -> Invite: # Invites have target_id set to null @@ -641,22 +645,22 @@ class AuditLogEntry(Hashable): return obj def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: - return self._state.get_emoji(target_id) or Object(id=target_id) + return self._state.get_emoji(target_id) or Object(id=target_id, type=Emoji) def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: return self._get_member(target_id) def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]: - return self.guild.get_stage_instance(target_id) or Object(id=target_id) + return self.guild.get_stage_instance(target_id) or Object(id=target_id, type=StageInstance) def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]: - return self._state.get_sticker(target_id) or Object(id=target_id) + return self._state.get_sticker(target_id) or Object(id=target_id, type=StageInstance) def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]: - return self.guild.get_thread(target_id) or Object(id=target_id) + return self.guild.get_thread(target_id) or Object(id=target_id, type=Thread) def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]: - return self.guild.get_scheduled_event(target_id) or Object(id=target_id) + return self.guild.get_scheduled_event(target_id) or Object(id=target_id, type=ScheduledEvent) def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]: - return self._automod_rules.get(target_id) or Object(target_id) + return self._automod_rules.get(target_id) or Object(target_id, type=AutoModRule) diff --git a/discord/client.py b/discord/client.py index e77e84ba3..75bbcb936 100644 --- a/discord/client.py +++ b/discord/client.py @@ -2328,8 +2328,8 @@ class Client: else: # The factory can't be a DMChannel or GroupChannel here guild_id = int(data['guild_id']) # type: ignore - guild = self.get_guild(guild_id) or Object(id=guild_id) - # GuildChannels expect a Guild, we may be passing an Object + guild = self._connection._get_or_create_unavailable_guild(guild_id) + # the factory should be a GuildChannel or Thread channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel diff --git a/discord/guild.py b/discord/guild.py index 31910066c..77a2148e5 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -457,6 +457,10 @@ class Guild(Hashable): return role + @classmethod + def _create_unavailable(cls, *, state: ConnectionState, guild_id: int) -> Guild: + return cls(state=state, data={'id': guild_id, 'unavailable': True}) # type: ignore + def _from_data(self, guild: Union[GuildPayload, PartialGuildPayload]) -> None: try: self._member_count: int = guild['member_count'] # type: ignore # Handled below diff --git a/discord/invite.py b/discord/invite.py index cf27c731c..835a13a34 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -547,8 +547,8 @@ class Invite(Hashable): if guild is not None: channel = (guild.get_channel(channel_id) or Object(id=channel_id)) if channel_id is not None else None else: - guild = Object(id=guild_id) if guild_id is not None else None - channel = Object(id=channel_id) if channel_id is not None else None + guild = state._get_or_create_unavailable_guild(guild_id) if guild_id is not None else None + channel = Object(id=channel_id) return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore diff --git a/discord/state.py b/discord/state.py index 795dd7482..7d1c235fb 100644 --- a/discord/state.py +++ b/discord/state.py @@ -772,6 +772,9 @@ class ConnectionState: guild = self._queued_guilds.get(guild_id) # type: ignore return guild + def _get_or_create_unavailable_guild(self, guild_id: int) -> Guild: + return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id) + def _add_guild(self, guild: Guild) -> None: self._guilds[guild.id] = guild