Browse Source

Move guild creation to connection state to prevent circular imports

pull/10109/head
dolfies 2 years ago
parent
commit
a91802e51b
  1. 7
      discord/application.py
  2. 26
      discord/client.py
  3. 4
      discord/connections.py
  4. 5
      discord/directory.py
  5. 8
      discord/emoji.py
  6. 6
      discord/guild.py
  7. 1
      discord/invite.py
  8. 18
      discord/partial_emoji.py
  9. 15
      discord/state.py
  10. 14
      discord/sticker.py
  11. 4
      discord/store.py

7
discord/application.py

@ -1948,6 +1948,7 @@ class PartialApplication(Hashable):
self.public: bool = data.get('integration_public', data.get('bot_public', True)) self.public: bool = data.get('integration_public', data.get('bot_public', True))
self.require_code_grant: bool = data.get('integration_require_code_grant', data.get('bot_require_code_grant', False)) self.require_code_grant: bool = data.get('integration_require_code_grant', data.get('bot_require_code_grant', False))
self._has_bot: bool = 'bot_public' in data self._has_bot: bool = 'bot_public' in data
self._guild: Optional[Guild] = state.create_guild(data['guild']) if 'guild' in data else None
# Hacky, but I want these to be persisted # Hacky, but I want these to be persisted
@ -1975,12 +1976,6 @@ class PartialApplication(Hashable):
} }
self.owner = state.create_user(payload) self.owner = state.create_user(payload)
self._guild: Optional[Guild] = None
if 'guild' in data:
from .guild import Guild
self._guild = Guild(state=state, data=data['guild'])
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>' return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'

26
discord/client.py

@ -53,7 +53,7 @@ from .user import _UserTag, User, ClientUser, Note
from .invite import Invite from .invite import Invite
from .template import Template from .template import Template
from .widget import Widget from .widget import Widget
from .guild import Guild, UserGuild from .guild import UserGuild
from .emoji import Emoji from .emoji import Emoji
from .channel import _private_channel_factory, _threaded_channel_factory, GroupChannel, PartialMessageable from .channel import _private_channel_factory, _threaded_channel_factory, GroupChannel, PartialMessageable
from .enums import ActivityType, ChannelType, ClientType, ConnectionType, EntitlementType, Status from .enums import ActivityType, ChannelType, ClientType, ConnectionType, EntitlementType, Status
@ -110,6 +110,7 @@ if TYPE_CHECKING:
from .read_state import ReadState from .read_state import ReadState
from .tutorial import Tutorial from .tutorial import Tutorial
from .file import File from .file import File
from .guild import Guild
from .types.snowflake import Snowflake as _Snowflake from .types.snowflake import Snowflake as _Snowflake
PrivateChannel = Union[DMChannel, GroupChannel] PrivateChannel = Union[DMChannel, GroupChannel]
@ -1791,8 +1792,9 @@ class Client:
:class:`.Guild` :class:`.Guild`
The guild from the ID. The guild from the ID.
""" """
data = await self.http.get_guild(guild_id, with_counts) state = self._connection
guild = Guild(data=data, state=self._connection) data = await state.http.get_guild(guild_id, with_counts)
guild = state.create_guild(data)
guild._cs_joined = True guild._cs_joined = True
return guild return guild
@ -1815,8 +1817,9 @@ class Client:
:class:`.Guild` :class:`.Guild`
The guild from the ID. The guild from the ID.
""" """
data = await self.http.get_guild_preview(guild_id) state = self._connection
return Guild(data=data, state=self._connection) data = await state.http.get_guild_preview(guild_id)
return state.create_guild(data)
async def create_guild( async def create_guild(
self, self,
@ -1860,17 +1863,18 @@ class Client:
The guild created. This is not the same guild that is The guild created. This is not the same guild that is
added to cache. added to cache.
""" """
state = self._connection
if icon is not MISSING: if icon is not MISSING:
icon_base64 = utils._bytes_to_base64_data(icon) icon_base64 = utils._bytes_to_base64_data(icon)
else: else:
icon_base64 = None icon_base64 = None
if code: if code:
data = await self.http.create_from_template(code, name, icon_base64) data = await state.http.create_from_template(code, name, icon_base64)
else: else:
data = await self.http.create_guild(name, icon_base64) data = await state.http.create_guild(name, icon_base64)
guild = Guild(data=data, state=self._connection) guild = state.create_guild(data)
guild._cs_joined = True guild._cs_joined = True
return guild return guild
@ -1900,7 +1904,7 @@ class Client:
""" """
state = self._connection state = self._connection
data = await state.http.join_guild(guild_id, lurking, state.session_id) data = await state.http.join_guild(guild_id, lurking, state.session_id)
guild = Guild(data=data, state=state) guild = state.create_guild(data)
guild._cs_joined = not lurking guild._cs_joined = not lurking
return guild return guild
@ -5114,7 +5118,7 @@ class Client:
""" """
state = self._connection state = self._connection
data = await state.http.hub_lookup(email) data = await state.http.hub_lookup(email)
return [Guild(state=state, data=d) for d in data.get('guilds_info', [])] # type: ignore return [state.create_guild(d) for d in data.get('guilds_info', [])] # type: ignore
@overload @overload
async def join_hub(self, guild: Snowflake, email: str, *, code: None = ...) -> None: async def join_hub(self, guild: Snowflake, email: str, *, code: None = ...) -> None:
@ -5158,7 +5162,7 @@ class Client:
return return
data = await state.http.join_hub(email, guild.id, code) data = await state.http.join_hub(email, guild.id, code)
return Guild(state=state, data=data['guild']) return state.create_guild(data['guild'])
async def pomelo_suggestion(self) -> str: async def pomelo_suggestion(self) -> str:
"""|coro| """|coro|

4
discord/connections.py

@ -228,8 +228,6 @@ class Connection(PartialConnection):
] ]
def _resolve_guild(self, data: IntegrationPayload) -> Guild: def _resolve_guild(self, data: IntegrationPayload) -> Guild:
from .guild import Guild
state = self._state state = self._state
guild_data = data.get('guild') guild_data = data.get('guild')
if not guild_data: if not guild_data:
@ -238,7 +236,7 @@ class Connection(PartialConnection):
guild_id = int(guild_data['id']) guild_id = int(guild_data['id'])
guild = state._get_guild(guild_id) guild = state._get_guild(guild_id)
if guild is None: if guild is None:
guild = Guild(data=guild_data, state=state) guild = state.create_guild(guild_data)
return guild return guild
async def edit( async def edit(

5
discord/directory.py

@ -33,6 +33,7 @@ if TYPE_CHECKING:
from datetime import datetime from datetime import datetime
from .channel import DirectoryChannel from .channel import DirectoryChannel
from .guild import Guild
from .member import Member from .member import Member
from .state import ConnectionState from .state import ConnectionState
from .types.directory import ( from .types.directory import (
@ -118,8 +119,6 @@ class DirectoryEntry:
return NotImplemented return NotImplemented
def _update(self, data: Union[DirectoryEntryPayload, PartialDirectoryEntryPayload]): def _update(self, data: Union[DirectoryEntryPayload, PartialDirectoryEntryPayload]):
from .guild import Guild
state = self._state state = self._state
self.type: DirectoryEntryType = try_enum(DirectoryEntryType, data['type']) self.type: DirectoryEntryType = try_enum(DirectoryEntryType, data['type'])
self.category: DirectoryCategory = try_enum(DirectoryCategory, data.get('primary_category_id', 0)) self.category: DirectoryCategory = try_enum(DirectoryCategory, data.get('primary_category_id', 0))
@ -129,7 +128,7 @@ class DirectoryEntry:
self.entity_id: int = int(data['entity_id']) self.entity_id: int = int(data['entity_id'])
guild_data = data.get('guild', data.get('guild_scheduled_event', {}).get('guild')) guild_data = data.get('guild', data.get('guild_scheduled_event', {}).get('guild'))
self.guild: Optional[Guild] = Guild(data=guild_data, state=state) if guild_data is not None else None self.guild: Optional[Guild] = state.create_guild(guild_data) if guild_data is not None else None
self.featurable: bool = guild_data.get('featurable_in_directory', False) if guild_data is not None else False self.featurable: bool = guild_data.get('featurable_in_directory', False) if guild_data is not None else False
event_data = data.get('guild_scheduled_event') event_data = data.get('guild_scheduled_event')

8
discord/emoji.py

@ -256,11 +256,13 @@ class Emoji(_EmojiTag, AssetMixin):
data = await self._state.http.edit_custom_emoji(self.guild_id, self.id, payload=payload, reason=reason) data = await self._state.http.edit_custom_emoji(self.guild_id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore # If guild is None, the http request would have failed return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore # If guild is None, the http request would have failed
async def fetch_guild(self): async def fetch_guild(self) -> Guild:
"""|coro| """|coro|
Retrieves the guild this emoji belongs to. Retrieves the guild this emoji belongs to.
.. versionadded:: 1.9
Raises Raises
------ ------
NotFound NotFound
@ -273,8 +275,6 @@ class Emoji(_EmojiTag, AssetMixin):
:class:`Guild` :class:`Guild`
The guild this emoji belongs to. The guild this emoji belongs to.
""" """
from .guild import Guild # Circular import
state = self._state state = self._state
data = await state.http.get_emoji_guild(self.id) data = await state.http.get_emoji_guild(self.id)
return Guild(state=state, data=data) return state.create_guild(data)

6
discord/guild.py

@ -103,8 +103,8 @@ from .partial_emoji import _EmojiTag, PartialEmoji
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import Snowflake, SnowflakeTime from .abc import Snowflake, SnowflakeTime
from .types.guild import ( from .types.guild import (
BaseGuild as BaseGuildPayload,
Guild as GuildPayload, Guild as GuildPayload,
PartialGuild as PartialGuildPayload,
RolePositionUpdate as RolePositionUpdatePayload, RolePositionUpdate as RolePositionUpdatePayload,
UserGuild as UserGuildPayload, UserGuild as UserGuildPayload,
) )
@ -501,7 +501,7 @@ class Guild(Hashable):
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600), 3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
} }
def __init__(self, *, data: Union[GuildPayload, PartialGuildPayload], state: ConnectionState) -> None: def __init__(self, *, data: Union[BaseGuildPayload, GuildPayload], state: ConnectionState) -> None:
self._chunked = False self._chunked = False
self._cs_joined: Optional[bool] = None self._cs_joined: Optional[bool] = None
self._roles: Dict[int, Role] = {} self._roles: Dict[int, Role] = {}
@ -621,7 +621,7 @@ class Guild(Hashable):
def _create_unavailable(cls, *, state: ConnectionState, guild_id: int) -> Guild: def _create_unavailable(cls, *, state: ConnectionState, guild_id: int) -> Guild:
return cls(state=state, data={'id': guild_id, 'unavailable': True}) # type: ignore return cls(state=state, data={'id': guild_id, 'unavailable': True}) # type: ignore
def _from_data(self, guild: Union[GuildPayload, PartialGuildPayload]) -> None: def _from_data(self, guild: Union[BaseGuildPayload, GuildPayload]) -> None:
try: try:
self._member_count: Optional[int] = guild['member_count'] # type: ignore # Handled below self._member_count: Optional[int] = guild['member_count'] # type: ignore # Handled below
except KeyError: except KeyError:

1
discord/invite.py

@ -62,7 +62,6 @@ if TYPE_CHECKING:
PartialChannel as InviteChannelPayload, PartialChannel as InviteChannelPayload,
) )
from .state import ConnectionState from .state import ConnectionState
from .guild import Guild
from .abc import GuildChannel from .abc import GuildChannel
from .user import User from .user import User

18
discord/partial_emoji.py

@ -24,11 +24,11 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
import re import re
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from .asset import Asset, AssetMixin
from . import utils from . import utils
from .asset import Asset, AssetMixin
# fmt: off # fmt: off
__all__ = ( __all__ = (
@ -37,12 +37,14 @@ __all__ = (
# fmt: on # fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import datetime
from typing_extensions import Self from typing_extensions import Self
from .guild import Guild
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji from .types.activity import ActivityEmoji
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
class _EmojiTag: class _EmojiTag:
@ -268,11 +270,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return await super().read() return await super().read()
async def fetch_guild(self): async def fetch_guild(self) -> Guild:
"""|coro| """|coro|
Retrieves the guild this emoji belongs to. Retrieves the guild this emoji belongs to.
.. versionadded:: 1.9
Raises Raises
------ ------
NotFound NotFound
@ -289,8 +293,6 @@ class PartialEmoji(_EmojiTag, AssetMixin):
:class:`Guild` :class:`Guild`
The guild this emoji belongs to. The guild this emoji belongs to.
""" """
from .guild import Guild # Circular import
if self.id is None: if self.id is None:
raise ValueError('PartialEmoji is not a custom emoji') raise ValueError('PartialEmoji is not a custom emoji')
if self._state is None: if self._state is None:
@ -298,4 +300,4 @@ class PartialEmoji(_EmojiTag, AssetMixin):
state = self._state state = self._state
data = await state.http.get_emoji_guild(self.id) data = await state.http.get_emoji_guild(self.id)
return Guild(state=state, data=data) return state.create_guild(data)

15
discord/state.py

@ -121,7 +121,7 @@ if TYPE_CHECKING:
from .types.user import User as UserPayload, PartialUser as PartialUserPayload from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload from .types.guild import BaseGuild as BaseGuildPayload, Guild as GuildPayload
from .types.message import ( from .types.message import (
Message as MessagePayload, Message as MessagePayload,
MessageSearchResult as MessageSearchResultPayload, MessageSearchResult as MessageSearchResultPayload,
@ -840,17 +840,17 @@ class ConnectionState:
def guilds(self) -> Sequence[Guild]: def guilds(self) -> Sequence[Guild]:
return utils.SequenceProxy(self._guilds.values()) return utils.SequenceProxy(self._guilds.values())
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]: def _get_guild(self, guild_id: Optional[int], /) -> Optional[Guild]:
# The keys of self._guilds are ints # The keys of self._guilds are ints
return self._guilds.get(guild_id) # type: ignore return self._guilds.get(guild_id) # type: ignore
def _get_or_create_unavailable_guild(self, guild_id: int) -> Guild: def _get_or_create_unavailable_guild(self, guild_id: int, /) -> Guild:
return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id) return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id)
def _add_guild(self, guild: Guild) -> None: def _add_guild(self, guild: Guild, /) -> None:
self._guilds[guild.id] = guild self._guilds[guild.id] = guild
def _remove_guild(self, guild: Guild) -> None: def _remove_guild(self, guild: Guild, /) -> None:
self._guilds.pop(guild.id, None) self._guilds.pop(guild.id, None)
for emoji in guild.emojis: for emoji in guild.emojis:
@ -861,6 +861,9 @@ class ConnectionState:
del guild del guild
def create_guild(self, guild: BaseGuildPayload, /) -> Guild:
return Guild(data=guild, state=self)
@property @property
def emojis(self) -> Sequence[Emoji]: def emojis(self) -> Sequence[Emoji]:
return utils.SequenceProxy(self._emojis.values()) return utils.SequenceProxy(self._emojis.values())
@ -923,7 +926,7 @@ class ConnectionState:
) )
def _add_guild_from_data(self, data: GuildPayload) -> Guild: def _add_guild_from_data(self, data: GuildPayload) -> Guild:
guild = Guild(data=data, state=self) guild = self.create_guild(data)
self._add_guild(guild) self._add_guild(guild)
return guild return guild

14
discord/sticker.py

@ -229,11 +229,13 @@ class StickerItem(_StickerTag):
cls, _ = _sticker_factory(data['type']) cls, _ = _sticker_factory(data['type'])
return cls(state=self._state, data=data) return cls(state=self._state, data=data)
async def fetch_guild(self): async def fetch_guild(self) -> Guild:
"""|coro| """|coro|
Retrieves the guild this sticker belongs to. Retrieves the guild this sticker belongs to.
.. versionadded:: 1.9
Raises Raises
------ ------
NotFound NotFound
@ -246,11 +248,9 @@ class StickerItem(_StickerTag):
:class:`Guild` :class:`Guild`
The guild this emoji belongs to. The guild this emoji belongs to.
""" """
from .guild import Guild # Circular import
state = self._state state = self._state
data = await state.http.get_sticker_guild(self.id) data = await state.http.get_sticker_guild(self.id)
return Guild(state=state, data=data) return state.create_guild(data)
class Sticker(_StickerTag): class Sticker(_StickerTag):
@ -532,7 +532,7 @@ class GuildSticker(Sticker):
""" """
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason) await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
async def fetch_guild(self): async def fetch_guild(self) -> Guild:
"""|coro| """|coro|
Retrieves the guild this sticker belongs to. Retrieves the guild this sticker belongs to.
@ -549,11 +549,9 @@ class GuildSticker(Sticker):
:class:`Guild` :class:`Guild`
The guild this emoji belongs to. The guild this emoji belongs to.
""" """
from .guild import Guild # Circular import
state = self._state state = self._state
data = await state.http.get_sticker_guild(self.id) data = await state.http.get_sticker_guild(self.id)
return Guild(state=state, data=data) return state.create_guild(data)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]: def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:

4
discord/store.py

@ -679,8 +679,6 @@ class StoreListing(Hashable):
return f'<StoreListing id={self.id} summary={self.summary!r} sku={self.sku!r}>' return f'<StoreListing id={self.id} summary={self.summary!r} sku={self.sku!r}>'
def _update(self, data: StoreListingPayload, application: Optional[PartialApplication] = None) -> None: def _update(self, data: StoreListingPayload, application: Optional[PartialApplication] = None) -> None:
from .guild import Guild
state = self._state state = self._state
self.summary, self.summary_localizations = _parse_localizations(data, 'summary') self.summary, self.summary_localizations = _parse_localizations(data, 'summary')
@ -693,7 +691,7 @@ class StoreListing(Hashable):
self.child_skus: List[SKU] = [SKU(data=sku, state=state) for sku in data.get('child_skus', [])] self.child_skus: List[SKU] = [SKU(data=sku, state=state) for sku in data.get('child_skus', [])]
self.alternative_skus: List[SKU] = [SKU(data=sku, state=state) for sku in data.get('alternative_skus', [])] self.alternative_skus: List[SKU] = [SKU(data=sku, state=state) for sku in data.get('alternative_skus', [])]
self.entitlement_branch_id: Optional[int] = _get_as_snowflake(data, 'entitlement_branch_id') self.entitlement_branch_id: Optional[int] = _get_as_snowflake(data, 'entitlement_branch_id')
self.guild: Optional[Guild] = Guild(data=data['guild'], state=state) if 'guild' in data else None self.guild: Optional[Guild] = state.create_guild(data['guild']) if 'guild' in data else None
self.published: bool = data.get('published', True) self.published: bool = data.get('published', True)
self.staff_note: Optional[StoreNote] = ( self.staff_note: Optional[StoreNote] = (
StoreNote(data=data['staff_notes'], state=state) if 'staff_notes' in data else None StoreNote(data=data['staff_notes'], state=state) if 'staff_notes' in data else None

Loading…
Cancel
Save