diff --git a/discord/app_commands/models.py b/discord/app_commands/models.py index f2e478d29..4f4f42812 100644 --- a/discord/app_commands/models.py +++ b/discord/app_commands/models.py @@ -32,6 +32,8 @@ from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionTy from ..mixins import Hashable from ..utils import _get_as_snowflake, parse_time, snowflake_time, MISSING from ..object import Object +from ..role import Role +from ..member import Member from typing import Any, Dict, Generic, List, TYPE_CHECKING, Optional, TypeVar, Union @@ -75,9 +77,7 @@ if TYPE_CHECKING: from ..guild import GuildChannel, Guild from ..channel import TextChannel from ..threads import Thread - from ..role import Role from ..user import User - from ..member import Member ApplicationCommandParent = Union['AppCommand', 'AppCommandGroup'] @@ -991,9 +991,11 @@ class AppCommandPermissions: self.permission: bool = data['permission'] _object = None + _type = MISSING if self.type is AppCommandPermissionType.user: _object = guild.get_member(self.id) or self._state.get_user(self.id) + _type = Member elif self.type is AppCommandPermissionType.channel: if self.id == (guild.id - 1): _object = AllChannels(guild) @@ -1001,9 +1003,10 @@ class AppCommandPermissions: _object = guild.get_channel(self.id) elif self.type is AppCommandPermissionType.role: _object = guild.get_role(self.id) + _type = Role if _object is None: - _object = Object(id=self.id) + _object = Object(id=self.id, type=_type) self.target: Union[Object, User, Member, Role, AllChannels, GuildChannel] = _object diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 902e9d7d5..7b50b3e05 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -33,7 +33,16 @@ from .invite import Invite from .mixins import Hashable from .object import Object from .permissions import PermissionOverwrite, Permissions -from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets +from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets, AutoModRule +from .role import Role +from .emoji import Emoji +from .member import Member +from .scheduled_event import ScheduledEvent +from .stage_instance import StageInstance +from .sticker import GuildSticker +from .threads import Thread +from .integrations import PartialIntegration +from .channel import StageChannel __all__ = ( 'AuditLogDiff', @@ -46,11 +55,7 @@ if TYPE_CHECKING: import datetime from . import abc - from .emoji import Emoji from .guild import Guild - from .member import Member - from .role import Role - from .scheduled_event import ScheduledEvent from .state import ConnectionState from .types.audit_log import ( AuditLogChange as AuditLogChangePayload, @@ -65,12 +70,7 @@ if TYPE_CHECKING: from .types.command import ApplicationCommandPermissions from .types.automod import AutoModerationTriggerMetadata, AutoModerationAction from .user import User - from .stage_instance import StageInstance - from .sticker import GuildSticker - from .threads import Thread - from .integrations import PartialIntegration from .app_commands import AppCommand - from .automod import AutoModRule, AutoModTrigger TargetType = Union[ Guild, @@ -127,7 +127,7 @@ def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Opti def _transform_roles(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[Role, Object]]: - return [entry.guild.get_role(int(role_id)) or Object(role_id) for role_id in data] + return [entry.guild.get_role(int(role_id)) or Object(role_id, type=Role) for role_id in data] def _transform_overwrites( @@ -148,7 +148,7 @@ def _transform_overwrites( target = entry._get_member(ow_id) if target is None: - target = Object(id=ow_id) + target = Object(id=ow_id, type=Role if ow_type == '0' else Member) overwrites.append((target, ow)) @@ -390,7 +390,7 @@ class AuditLogChanges: role = g.get_role(role_id) if role is None: - role = Object(id=role_id) + role = Object(id=role_id, type=Role) role.name = e['name'] # type: ignore # Object doesn't usually have name data.append(role) @@ -581,17 +581,17 @@ class AuditLogEntry(Hashable): elif the_type == '0': role = self.guild.get_role(instance_id) if role is None: - role = Object(id=instance_id) + role = Object(id=instance_id, type=Role) role.name = extra.get('role_name') # type: ignore # Object doesn't usually have name self.extra = role elif self.action.name.startswith('stage_instance'): channel_id = int(extra['channel_id']) self.extra = _AuditLogProxyStageInstanceAction( - channel=self.guild.get_channel(channel_id) or Object(id=channel_id) + channel=self.guild.get_channel(channel_id) or Object(id=channel_id, type=StageChannel) ) elif self.action.name.startswith('app_command'): - application_id = int(extra['application_id']) - self.extra = self._get_integration_by_app_id(application_id) or Object(application_id) + app_id = int(extra['application_id']) + self.extra = self._get_integration_by_app_id(app_id) or Object(app_id, type=PartialIntegration) # this key is not present when the above is present, typically. # It's a list of { new_value: a, old_value: b, key: c } @@ -683,7 +683,7 @@ class AuditLogEntry(Hashable): return self._get_member(target_id) def _convert_target_role(self, target_id: int) -> Union[Role, Object]: - return self.guild.get_role(target_id) or Object(id=target_id) + return self.guild.get_role(target_id) or Object(id=target_id, type=Role) def _convert_target_invite(self, target_id: None) -> Invite: # invites have target_id set to null @@ -707,31 +707,52 @@ class AuditLogEntry(Hashable): return obj def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: - return self._state.get_emoji(target_id) or Object(id=target_id) + return self._state.get_emoji(target_id) or Object(id=target_id, type=Emoji) def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: return self._get_member(target_id) def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]: - return self.guild.get_stage_instance(target_id) or Object(id=target_id) + return self.guild.get_stage_instance(target_id) or Object(id=target_id, type=StageInstance) def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]: - return self._state.get_sticker(target_id) or Object(id=target_id) + return self._state.get_sticker(target_id) or Object(id=target_id, type=StageInstance) def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]: - return self.guild.get_thread(target_id) or Object(id=target_id) + return self.guild.get_thread(target_id) or Object(id=target_id, type=Thread) def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]: - return self.guild.get_scheduled_event(target_id) or Object(id=target_id) + return self.guild.get_scheduled_event(target_id) or Object(id=target_id, type=ScheduledEvent) def _convert_target_integration(self, target_id: int) -> Union[PartialIntegration, Object]: - return self._get_integration(target_id) or Object(target_id) + return self._get_integration(target_id) or Object(target_id, type=PartialIntegration) def _convert_target_app_command(self, target_id: int) -> Union[AppCommand, Object]: - return self._get_app_command(target_id) or Object(target_id) + target = self._get_app_command(target_id) + if not target: + # circular import + from .app_commands import AppCommand + + target = Object(target_id, type=AppCommand) + + return target def _convert_target_integration_or_app_command(self, target_id: int) -> Union[PartialIntegration, AppCommand, Object]: - return self._get_integration_by_app_id(target_id) or self._get_app_command(target_id) or Object(target_id) + target = self._get_integration_by_app_id(target_id) or self._get_app_command(target_id) + if not target: + try: + # get application id from extras + # if it matches target id, type should be integration + target_app = self.extra + # extra should be an Object or PartialIntegration + app_id = target_app.application_id if isinstance(target_app, PartialIntegration) else target_app.id # type: ignore + type = PartialIntegration if target_id == app_id else AppCommand + except AttributeError: + return Object(target_id) + else: + return Object(target_id, type=type) + + return target def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]: - return self._automod_rules.get(target_id) or Object(target_id) + return self._automod_rules.get(target_id) or Object(target_id, type=AutoModRule) diff --git a/discord/client.py b/discord/client.py index d14148477..b8adcf525 100644 --- a/discord/client.py +++ b/discord/client.py @@ -1885,8 +1885,8 @@ class Client: else: # the factory can't be a DMChannel or GroupChannel here guild_id = int(data['guild_id']) # type: ignore - guild = self.get_guild(guild_id) or Object(id=guild_id) - # GuildChannels expect a Guild, we may be passing an Object + guild = self._connection._get_or_create_unavailable_guild(guild_id) + # the factory should be a GuildChannel or Thread channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel diff --git a/discord/invite.py b/discord/invite.py index 65c7412c5..13a516397 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -458,7 +458,7 @@ class Invite(Hashable): if guild is not None: channel = guild.get_channel(channel_id) or Object(id=channel_id) else: - guild = Object(id=guild_id) if guild_id is not None else None + guild = state._get_or_create_unavailable_guild(guild_id) if guild_id is not None else None channel = Object(id=channel_id) return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore