diff --git a/discord/audit_logs.py b/discord/audit_logs.py index f6a931509..ed17f2607 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -24,14 +24,15 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Dict, List, TYPE_CHECKING -from . import utils, enums -from .object import Object -from .permissions import PermissionOverwrite, Permissions +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, List, Optional, Tuple, Union + +from . import enums, utils +from .asset import Asset from .colour import Colour from .invite import Invite from .mixins import Hashable -from .asset import Asset +from .object import Object +from .permissions import PermissionOverwrite, Permissions __all__ = ( 'AuditLogDiff', @@ -39,58 +40,68 @@ __all__ = ( 'AuditLogEntry', ) + if TYPE_CHECKING: - from .types.audit_log import ( - AuditLogChange as AuditLogChangePayload, - AuditLogEntry as AuditLogEntryPayload, - ) + import datetime + + from . import abc + from .emoji import Emoji from .guild import Guild + from .member import Member + from .role import Role + from .types.audit_log import AuditLogChange as AuditLogChangePayload + from .types.audit_log import AuditLogEntry as AuditLogEntryPayload + from .types.channel import PermissionOverwrite as PermissionOverwritePayload + from .types.role import Role as RolePayload + from .types.snowflake import Snowflake from .user import User -def _transform_verification_level(entry, data): +def _transform_verification_level(entry: AuditLogEntry, data: int) -> enums.VerificationLevel: return enums.try_enum(enums.VerificationLevel, data) -def _transform_default_notifications(entry, data): +def _transform_default_notifications(entry: AuditLogEntry, data: int) -> enums.NotificationLevel: return enums.try_enum(enums.NotificationLevel, data) -def _transform_explicit_content_filter(entry, data): +def _transform_explicit_content_filter(entry: AuditLogEntry, data: int) -> enums.ContentFilter: return enums.try_enum(enums.ContentFilter, data) -def _transform_permissions(entry, data): +def _transform_permissions(entry: AuditLogEntry, data: str) -> Permissions: return Permissions(int(data)) -def _transform_color(entry, data): +def _transform_color(entry: AuditLogEntry, data: int) -> Colour: return Colour(data) -def _transform_snowflake(entry, data): +def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: return int(data) -def _transform_channel(entry, data): +def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Object]: if data is None: return None return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_owner_id(entry, data): +def _transform_owner_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]: if data is None: return None return entry._get_member(int(data)) -def _transform_inviter_id(entry, data): +def _transform_inviter_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]: if data is None: return None return entry._get_member(int(data)) -def _transform_overwrites(entry, data): +def _transform_overwrites( + entry: AuditLogEntry, data: List[PermissionOverwritePayload] +) -> List[Tuple[Object, PermissionOverwrite]]: overwrites = [] for elem in data: allow = Permissions(elem['allow']) @@ -113,32 +124,32 @@ def _transform_overwrites(entry, data): return overwrites -def _transform_channeltype(entry, data): +def _transform_channeltype(entry: AuditLogEntry, data: int) -> enums.ChannelType: return enums.try_enum(enums.ChannelType, data) -def _transform_voiceregion(entry, data): +def _transform_voiceregion(entry: AuditLogEntry, data: int) -> enums.VoiceRegion: return enums.try_enum(enums.VoiceRegion, data) -def _transform_video_quality_mode(entry, data): +def _transform_video_quality_mode(entry: AuditLogEntry, data: int) -> enums.VideoQualityMode: return enums.try_enum(enums.VideoQualityMode, data) -def _transform_icon(entry, data): +def _transform_icon(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: if data is None: return None return Asset._from_guild_icon(entry._state, entry.guild.id, data) -def _transform_avatar(entry, data): +def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: if data is None: return None return Asset._from_avatar(entry._state, entry._target_id, data) -def _guild_hash_transformer(path): - def _transform(entry, data): +def _guild_hash_transformer(path: str) -> Callable[['AuditLogEntry', Optional[str]], Optional[Asset]]: + def _transform(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: if data is None: return None return Asset._from_guild_image(entry._state, entry.guild.id, data, path=path) @@ -147,20 +158,31 @@ def _guild_hash_transformer(path): class AuditLogDiff: - def __len__(self): + def __len__(self) -> int: return len(self.__dict__) - def __iter__(self): - return iter(self.__dict__.items()) + def __iter__(self) -> Generator[Tuple[str, Any], None, None]: + yield from self.__dict__.items() - def __repr__(self): + def __repr__(self) -> str: values = ' '.join('%s=%r' % item for item in self.__dict__.items()) return f'' + if TYPE_CHECKING: + + def __getattr__(self, item: str) -> Any: + ... + + def __setattr__(self, key: str, value: Any) -> Any: + ... + + +Transformer = Callable[["AuditLogEntry", Any], Any] + class AuditLogChanges: # fmt: off - TRANSFORMERS = { + TRANSFORMERS: ClassVar[Dict[str, Tuple[Optional[str], Optional[Transformer]]]] = { 'verification_level': (None, _transform_verification_level), 'explicit_content_filter': (None, _transform_explicit_content_filter), 'allow': (None, _transform_permissions), @@ -191,7 +213,7 @@ class AuditLogChanges: } # fmt: on - def __init__(self, entry, data: List[AuditLogChangePayload]): + def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]): self.before = AuditLogDiff() self.after = AuditLogDiff() @@ -206,12 +228,16 @@ class AuditLogChanges: self._handle_role(self.after, self.before, entry, elem['new_value']) continue - transformer = self.TRANSFORMERS.get(attr) - if transformer: - key, transformer = transformer + try: + key, transformer = self.TRANSFORMERS[attr] + except (ValueError, KeyError): + transformer = None + else: if key: attr = key + transformer: Optional[Transformer] + try: before = elem['old_value'] except KeyError: @@ -240,15 +266,15 @@ class AuditLogChanges: self.after.expire_behaviour = self.after.expire_behavior self.before.expire_behaviour = self.before.expire_behavior - def __repr__(self): + def __repr__(self) -> str: return f'' - def _handle_role(self, first, second, entry, elem): + def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None: if not hasattr(first, 'roles'): setattr(first, 'roles', []) data = [] - g: Guild = entry.guild + g: Guild = entry.guild # type: ignore for e in elem: role_id = int(e['id']) @@ -263,6 +289,25 @@ class AuditLogChanges: setattr(second, 'roles', data) +class _AuditLogProxyMemberPrune: + delete_member_days: int + members_removed: int + + +class _AuditLogProxyMemberMoveOrMessageDelete: + channel: abc.GuildChannel + count: int + + +class _AuditLogProxyMemberDisconnect: + count: int + + +class _AuditLogProxyPinAction: + channel: abc.GuildChannel + message_id: int + + class AuditLogEntry(Hashable): r"""Represents an Audit Log entry. @@ -306,13 +351,13 @@ class AuditLogEntry(Hashable): which actions have this field filled out. """ - def __init__(self, *, users: Dict[str, User], data: AuditLogEntryPayload, guild: Guild): + def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): self._state = guild._state self.guild = guild self._users = users self._from_data(data) - def _from_data(self, data): + def _from_data(self, data: AuditLogEntryPayload) -> None: self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) self.id = int(data['id']) @@ -323,31 +368,30 @@ class AuditLogEntry(Hashable): if isinstance(self.action, enums.AuditLogAction) and self.extra: if self.action is enums.AuditLogAction.member_prune: # member prune has two keys with useful information - self.extra = type('_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()})() + self.extra: _AuditLogProxyMemberPrune = type( + '_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()} + )() elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: channel_id = int(self.extra['channel_id']) elems = { 'count': int(self.extra['count']), 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), } - self.extra = type('_AuditLogProxy', (), elems)() + self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)() elif self.action is enums.AuditLogAction.member_disconnect: # The member disconnect action has a dict with some information elems = { 'count': int(self.extra['count']), } - self.extra = type('_AuditLogProxy', (), elems)() + self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)() elif self.action.name.endswith('pin'): # the pin actions have a dict with some information channel_id = int(self.extra['channel_id']) - message_id = int(self.extra['message_id']) - # fmt: off elems = { 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), - 'message_id': message_id + 'message_id': int(self.extra['message_id']), } - # fmt: on - self.extra = type('_AuditLogProxy', (), elems)() + self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)() elif self.action.name.startswith('overwrite_'): # the overwrite_ actions have a dict with some information instance_id = int(self.extra['id']) @@ -359,7 +403,18 @@ class AuditLogEntry(Hashable): if role is None: role = Object(id=instance_id) role.name = self.extra.get('role_name') # type: ignore - self.extra = role + self.extra: Role = role + + # fmt: off + self.extra: Union[ + _AuditLogProxyMemberPrune, + _AuditLogProxyMemberMoveOrMessageDelete, + _AuditLogProxyMemberDisconnect, + _AuditLogProxyPinAction, + Member, User, None, + Role, + ] + # fmt: on # this key is not present when the above is present, typically. # It's a list of { new_value: a, old_value: b, key: c } @@ -368,22 +423,22 @@ class AuditLogEntry(Hashable): # into meaningful data when requested self._changes = data.get('changes', []) - self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) + self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore self._target_id = utils._get_as_snowflake(data, 'target_id') - def _get_member(self, user_id): + def _get_member(self, user_id: int) -> Union[Member, User, None]: return self.guild.get_member(user_id) or self._users.get(user_id) - def __repr__(self): + def __repr__(self) -> str: return f'' @utils.cached_property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the entry's creation time in UTC.""" return utils.snowflake_time(self.id) @utils.cached_property - def target(self): + def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, Object, None]: try: converter = getattr(self, '_convert_target_' + self.action.target_type) except AttributeError: @@ -392,46 +447,40 @@ class AuditLogEntry(Hashable): return converter(self._target_id) @utils.cached_property - def category(self): + def category(self) -> enums.AuditLogActionCategory: """Optional[:class:`AuditLogActionCategory`]: The category of the action, if applicable.""" return self.action.category @utils.cached_property - def changes(self): + def changes(self) -> AuditLogChanges: """:class:`AuditLogChanges`: The list of changes this entry has.""" obj = AuditLogChanges(self, self._changes) del self._changes return obj @utils.cached_property - def before(self): + def before(self) -> AuditLogDiff: """:class:`AuditLogDiff`: The target's prior state.""" return self.changes.before @utils.cached_property - def after(self): + def after(self) -> AuditLogDiff: """:class:`AuditLogDiff`: The target's subsequent state.""" return self.changes.after - def _convert_target_guild(self, target_id): + def _convert_target_guild(self, target_id: int) -> Guild: return self.guild - def _convert_target_channel(self, target_id): - ch = self.guild.get_channel(target_id) - if ch is None: - return Object(id=target_id) - return ch + def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]: + return self.guild.get_channel(target_id) or Object(id=target_id) - def _convert_target_user(self, target_id): + def _convert_target_user(self, target_id: int) -> Union[Member, User, None]: return self._get_member(target_id) - def _convert_target_role(self, target_id): - role = self.guild.get_role(target_id) - if role is None: - return Object(id=target_id) - return role + def _convert_target_role(self, target_id: int) -> Union[Role, Object]: + return self.guild.get_role(target_id) or Object(id=target_id) - def _convert_target_invite(self, target_id): + def _convert_target_invite(self, target_id: int) -> Invite: # invites have target_id set to null # so figure out which change has the full invite data changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after @@ -444,15 +493,15 @@ class AuditLogEntry(Hashable): 'uses': changeset.uses, } - obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) + obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore try: obj.inviter = changeset.inviter except AttributeError: pass return obj - def _convert_target_emoji(self, target_id): + def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: return self._state.get_emoji(target_id) or Object(id=target_id) - def _convert_target_message(self, target_id): + def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: return self._get_member(target_id)