Browse Source

Add Object.type to Objects where a type can be determined

pull/10109/head
z03h 3 years ago
committed by dolfies
parent
commit
39a3fc341f
  1. 38
      discord/audit_logs.py
  2. 4
      discord/client.py
  3. 4
      discord/guild.py
  4. 4
      discord/invite.py
  5. 3
      discord/state.py

38
discord/audit_logs.py

@ -33,7 +33,15 @@ from .invite import Invite
from .mixins import Hashable from .mixins import Hashable
from .object import Object from .object import Object
from .permissions import PermissionOverwrite, Permissions from .permissions import PermissionOverwrite, Permissions
from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets, AutoModRule
from .role import Role
from .emoji import Emoji
from .member import Member
from .scheduled_event import ScheduledEvent
from .stage_instance import StageInstance
from .sticker import GuildSticker
from .threads import Thread
from .channel import StageChannel
__all__ = ( __all__ = (
'AuditLogDiff', 'AuditLogDiff',
@ -46,11 +54,7 @@ if TYPE_CHECKING:
import datetime import datetime
from . import abc from . import abc
from .emoji import Emoji
from .guild import Guild from .guild import Guild
from .member import Member
from .role import Role
from .scheduled_event import ScheduledEvent
from .state import ConnectionState from .state import ConnectionState
from .types.audit_log import ( from .types.audit_log import (
AuditLogChange as AuditLogChangePayload, AuditLogChange as AuditLogChangePayload,
@ -123,7 +127,7 @@ def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Opti
def _transform_roles(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[Role, Object]]: def _transform_roles(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[Role, Object]]:
return [entry.guild.get_role(int(role_id)) or Object(role_id) for role_id in data] return [entry.guild.get_role(int(role_id)) or Object(role_id, type=Role) for role_id in data]
def _transform_overwrites( def _transform_overwrites(
@ -144,7 +148,7 @@ def _transform_overwrites(
target = entry._get_member(ow_id) target = entry._get_member(ow_id)
if target is None: if target is None:
target = Object(id=ow_id) target = Object(id=ow_id, type=Role if ow_type == '0' else Member)
overwrites.append((target, ow)) overwrites.append((target, ow))
@ -366,7 +370,7 @@ class AuditLogChanges:
role = g.get_role(role_id) role = g.get_role(role_id)
if role is None: if role is None:
role = Object(id=role_id) role = Object(id=role_id, type=Role)
role.name = e['name'] # type: ignore # Object doesn't usually have name role.name = e['name'] # type: ignore # Object doesn't usually have name
data.append(role) data.append(role)
@ -537,13 +541,13 @@ class AuditLogEntry(Hashable):
elif the_type == '0': elif the_type == '0':
role = self.guild.get_role(instance_id) role = self.guild.get_role(instance_id)
if role is None: if role is None:
role = Object(id=instance_id) role = Object(id=instance_id, type=Role)
role.name = extra.get('role_name') # type: ignore # Object doesn't usually have name role.name = extra.get('role_name') # type: ignore # Object doesn't usually have name
self.extra = role self.extra = role
elif self.action.name.startswith('stage_instance'): elif self.action.name.startswith('stage_instance'):
channel_id = int(extra['channel_id']) channel_id = int(extra['channel_id'])
self.extra = _AuditLogProxyStageInstanceAction( self.extra = _AuditLogProxyStageInstanceAction(
channel=self.guild.get_channel(channel_id) or Object(id=channel_id) channel=self.guild.get_channel(channel_id) or Object(id=channel_id, type=StageChannel)
) )
# this key is not present when the above is present, typically. # this key is not present when the above is present, typically.
@ -617,7 +621,7 @@ class AuditLogEntry(Hashable):
return self._get_member(target_id) return self._get_member(target_id)
def _convert_target_role(self, target_id: int) -> Union[Role, Object]: def _convert_target_role(self, target_id: int) -> Union[Role, Object]:
return self.guild.get_role(target_id) or Object(id=target_id) return self.guild.get_role(target_id) or Object(id=target_id, type=Role)
def _convert_target_invite(self, target_id: None) -> Invite: def _convert_target_invite(self, target_id: None) -> Invite:
# Invites have target_id set to null # Invites have target_id set to null
@ -641,22 +645,22 @@ class AuditLogEntry(Hashable):
return obj return obj
def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]:
return self._state.get_emoji(target_id) or Object(id=target_id) return self._state.get_emoji(target_id) or Object(id=target_id, type=Emoji)
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
return self._get_member(target_id) return self._get_member(target_id)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]: def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
return self.guild.get_stage_instance(target_id) or Object(id=target_id) return self.guild.get_stage_instance(target_id) or Object(id=target_id, type=StageInstance)
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]: def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
return self._state.get_sticker(target_id) or Object(id=target_id) return self._state.get_sticker(target_id) or Object(id=target_id, type=StageInstance)
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]: def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id) return self.guild.get_thread(target_id) or Object(id=target_id, type=Thread)
def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]: def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]:
return self.guild.get_scheduled_event(target_id) or Object(id=target_id) return self.guild.get_scheduled_event(target_id) or Object(id=target_id, type=ScheduledEvent)
def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]: def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]:
return self._automod_rules.get(target_id) or Object(target_id) return self._automod_rules.get(target_id) or Object(target_id, type=AutoModRule)

4
discord/client.py

@ -2328,8 +2328,8 @@ class Client:
else: else:
# The factory can't be a DMChannel or GroupChannel here # The factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore guild_id = int(data['guild_id']) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id) guild = self._connection._get_or_create_unavailable_guild(guild_id)
# GuildChannels expect a Guild, we may be passing an Object # the factory should be a GuildChannel or Thread
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
return channel return channel

4
discord/guild.py

@ -457,6 +457,10 @@ class Guild(Hashable):
return role return role
@classmethod
def _create_unavailable(cls, *, state: ConnectionState, guild_id: int) -> Guild:
return cls(state=state, data={'id': guild_id, 'unavailable': True}) # type: ignore
def _from_data(self, guild: Union[GuildPayload, PartialGuildPayload]) -> None: def _from_data(self, guild: Union[GuildPayload, PartialGuildPayload]) -> None:
try: try:
self._member_count: int = guild['member_count'] # type: ignore # Handled below self._member_count: int = guild['member_count'] # type: ignore # Handled below

4
discord/invite.py

@ -547,8 +547,8 @@ class Invite(Hashable):
if guild is not None: if guild is not None:
channel = (guild.get_channel(channel_id) or Object(id=channel_id)) if channel_id is not None else None channel = (guild.get_channel(channel_id) or Object(id=channel_id)) if channel_id is not None else None
else: else:
guild = Object(id=guild_id) if guild_id is not None else None guild = state._get_or_create_unavailable_guild(guild_id) if guild_id is not None else None
channel = Object(id=channel_id) if channel_id is not None else None channel = Object(id=channel_id)
return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore

3
discord/state.py

@ -772,6 +772,9 @@ class ConnectionState:
guild = self._queued_guilds.get(guild_id) # type: ignore guild = self._queued_guilds.get(guild_id) # type: ignore
return guild return guild
def _get_or_create_unavailable_guild(self, guild_id: int) -> Guild:
return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id)
def _add_guild(self, guild: Guild) -> None: def _add_guild(self, guild: Guild) -> None:
self._guilds[guild.id] = guild self._guilds[guild.id] = guild

Loading…
Cancel
Save