Browse Source

Add Object.type to Objects where a type can be determined

pull/10109/head
z03h 3 years ago
committed by dolfies
parent
commit
39a3fc341f
  1. 38
      discord/audit_logs.py
  2. 4
      discord/client.py
  3. 4
      discord/guild.py
  4. 4
      discord/invite.py
  5. 3
      discord/state.py

38
discord/audit_logs.py

@ -33,7 +33,15 @@ from .invite import Invite
from .mixins import Hashable
from .object import Object
from .permissions import PermissionOverwrite, Permissions
from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets
from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets, AutoModRule
from .role import Role
from .emoji import Emoji
from .member import Member
from .scheduled_event import ScheduledEvent
from .stage_instance import StageInstance
from .sticker import GuildSticker
from .threads import Thread
from .channel import StageChannel
__all__ = (
'AuditLogDiff',
@ -46,11 +54,7 @@ if TYPE_CHECKING:
import datetime
from . import abc
from .emoji import Emoji
from .guild import Guild
from .member import Member
from .role import Role
from .scheduled_event import ScheduledEvent
from .state import ConnectionState
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
@ -123,7 +127,7 @@ def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Opti
def _transform_roles(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[Role, Object]]:
return [entry.guild.get_role(int(role_id)) or Object(role_id) for role_id in data]
return [entry.guild.get_role(int(role_id)) or Object(role_id, type=Role) for role_id in data]
def _transform_overwrites(
@ -144,7 +148,7 @@ def _transform_overwrites(
target = entry._get_member(ow_id)
if target is None:
target = Object(id=ow_id)
target = Object(id=ow_id, type=Role if ow_type == '0' else Member)
overwrites.append((target, ow))
@ -366,7 +370,7 @@ class AuditLogChanges:
role = g.get_role(role_id)
if role is None:
role = Object(id=role_id)
role = Object(id=role_id, type=Role)
role.name = e['name'] # type: ignore # Object doesn't usually have name
data.append(role)
@ -537,13 +541,13 @@ class AuditLogEntry(Hashable):
elif the_type == '0':
role = self.guild.get_role(instance_id)
if role is None:
role = Object(id=instance_id)
role = Object(id=instance_id, type=Role)
role.name = extra.get('role_name') # type: ignore # Object doesn't usually have name
self.extra = role
elif self.action.name.startswith('stage_instance'):
channel_id = int(extra['channel_id'])
self.extra = _AuditLogProxyStageInstanceAction(
channel=self.guild.get_channel(channel_id) or Object(id=channel_id)
channel=self.guild.get_channel(channel_id) or Object(id=channel_id, type=StageChannel)
)
# this key is not present when the above is present, typically.
@ -617,7 +621,7 @@ class AuditLogEntry(Hashable):
return self._get_member(target_id)
def _convert_target_role(self, target_id: int) -> Union[Role, Object]:
return self.guild.get_role(target_id) or Object(id=target_id)
return self.guild.get_role(target_id) or Object(id=target_id, type=Role)
def _convert_target_invite(self, target_id: None) -> Invite:
# Invites have target_id set to null
@ -641,22 +645,22 @@ class AuditLogEntry(Hashable):
return obj
def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]:
return self._state.get_emoji(target_id) or Object(id=target_id)
return self._state.get_emoji(target_id) or Object(id=target_id, type=Emoji)
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
return self._get_member(target_id)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
return self.guild.get_stage_instance(target_id) or Object(id=target_id, type=StageInstance)
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
return self._state.get_sticker(target_id) or Object(id=target_id)
return self._state.get_sticker(target_id) or Object(id=target_id, type=StageInstance)
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id)
return self.guild.get_thread(target_id) or Object(id=target_id, type=Thread)
def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]:
return self.guild.get_scheduled_event(target_id) or Object(id=target_id)
return self.guild.get_scheduled_event(target_id) or Object(id=target_id, type=ScheduledEvent)
def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]:
return self._automod_rules.get(target_id) or Object(target_id)
return self._automod_rules.get(target_id) or Object(target_id, type=AutoModRule)

4
discord/client.py

@ -2328,8 +2328,8 @@ class Client:
else:
# The factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id)
# GuildChannels expect a Guild, we may be passing an Object
guild = self._connection._get_or_create_unavailable_guild(guild_id)
# the factory should be a GuildChannel or Thread
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
return channel

4
discord/guild.py

@ -457,6 +457,10 @@ class Guild(Hashable):
return role
@classmethod
def _create_unavailable(cls, *, state: ConnectionState, guild_id: int) -> Guild:
return cls(state=state, data={'id': guild_id, 'unavailable': True}) # type: ignore
def _from_data(self, guild: Union[GuildPayload, PartialGuildPayload]) -> None:
try:
self._member_count: int = guild['member_count'] # type: ignore # Handled below

4
discord/invite.py

@ -547,8 +547,8 @@ class Invite(Hashable):
if guild is not None:
channel = (guild.get_channel(channel_id) or Object(id=channel_id)) if channel_id is not None else None
else:
guild = Object(id=guild_id) if guild_id is not None else None
channel = Object(id=channel_id) if channel_id is not None else None
guild = state._get_or_create_unavailable_guild(guild_id) if guild_id is not None else None
channel = Object(id=channel_id)
return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore

3
discord/state.py

@ -772,6 +772,9 @@ class ConnectionState:
guild = self._queued_guilds.get(guild_id) # type: ignore
return guild
def _get_or_create_unavailable_guild(self, guild_id: int) -> Guild:
return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id)
def _add_guild(self, guild: Guild) -> None:
self._guilds[guild.id] = guild

Loading…
Cancel
Save