Browse Source

Move guild creation to connection state to prevent circular imports

pull/10109/head
dolfies 2 years ago
parent
commit
a91802e51b
  1. 7
      discord/application.py
  2. 26
      discord/client.py
  3. 4
      discord/connections.py
  4. 5
      discord/directory.py
  5. 8
      discord/emoji.py
  6. 6
      discord/guild.py
  7. 1
      discord/invite.py
  8. 18
      discord/partial_emoji.py
  9. 15
      discord/state.py
  10. 14
      discord/sticker.py
  11. 4
      discord/store.py

7
discord/application.py

@ -1948,6 +1948,7 @@ class PartialApplication(Hashable):
self.public: bool = data.get('integration_public', data.get('bot_public', True))
self.require_code_grant: bool = data.get('integration_require_code_grant', data.get('bot_require_code_grant', False))
self._has_bot: bool = 'bot_public' in data
self._guild: Optional[Guild] = state.create_guild(data['guild']) if 'guild' in data else None
# Hacky, but I want these to be persisted
@ -1975,12 +1976,6 @@ class PartialApplication(Hashable):
}
self.owner = state.create_user(payload)
self._guild: Optional[Guild] = None
if 'guild' in data:
from .guild import Guild
self._guild = Guild(state=state, data=data['guild'])
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'

26
discord/client.py

@ -53,7 +53,7 @@ from .user import _UserTag, User, ClientUser, Note
from .invite import Invite
from .template import Template
from .widget import Widget
from .guild import Guild, UserGuild
from .guild import UserGuild
from .emoji import Emoji
from .channel import _private_channel_factory, _threaded_channel_factory, GroupChannel, PartialMessageable
from .enums import ActivityType, ChannelType, ClientType, ConnectionType, EntitlementType, Status
@ -110,6 +110,7 @@ if TYPE_CHECKING:
from .read_state import ReadState
from .tutorial import Tutorial
from .file import File
from .guild import Guild
from .types.snowflake import Snowflake as _Snowflake
PrivateChannel = Union[DMChannel, GroupChannel]
@ -1791,8 +1792,9 @@ class Client:
:class:`.Guild`
The guild from the ID.
"""
data = await self.http.get_guild(guild_id, with_counts)
guild = Guild(data=data, state=self._connection)
state = self._connection
data = await state.http.get_guild(guild_id, with_counts)
guild = state.create_guild(data)
guild._cs_joined = True
return guild
@ -1815,8 +1817,9 @@ class Client:
:class:`.Guild`
The guild from the ID.
"""
data = await self.http.get_guild_preview(guild_id)
return Guild(data=data, state=self._connection)
state = self._connection
data = await state.http.get_guild_preview(guild_id)
return state.create_guild(data)
async def create_guild(
self,
@ -1860,17 +1863,18 @@ class Client:
The guild created. This is not the same guild that is
added to cache.
"""
state = self._connection
if icon is not MISSING:
icon_base64 = utils._bytes_to_base64_data(icon)
else:
icon_base64 = None
if code:
data = await self.http.create_from_template(code, name, icon_base64)
data = await state.http.create_from_template(code, name, icon_base64)
else:
data = await self.http.create_guild(name, icon_base64)
data = await state.http.create_guild(name, icon_base64)
guild = Guild(data=data, state=self._connection)
guild = state.create_guild(data)
guild._cs_joined = True
return guild
@ -1900,7 +1904,7 @@ class Client:
"""
state = self._connection
data = await state.http.join_guild(guild_id, lurking, state.session_id)
guild = Guild(data=data, state=state)
guild = state.create_guild(data)
guild._cs_joined = not lurking
return guild
@ -5114,7 +5118,7 @@ class Client:
"""
state = self._connection
data = await state.http.hub_lookup(email)
return [Guild(state=state, data=d) for d in data.get('guilds_info', [])] # type: ignore
return [state.create_guild(d) for d in data.get('guilds_info', [])] # type: ignore
@overload
async def join_hub(self, guild: Snowflake, email: str, *, code: None = ...) -> None:
@ -5158,7 +5162,7 @@ class Client:
return
data = await state.http.join_hub(email, guild.id, code)
return Guild(state=state, data=data['guild'])
return state.create_guild(data['guild'])
async def pomelo_suggestion(self) -> str:
"""|coro|

4
discord/connections.py

@ -228,8 +228,6 @@ class Connection(PartialConnection):
]
def _resolve_guild(self, data: IntegrationPayload) -> Guild:
from .guild import Guild
state = self._state
guild_data = data.get('guild')
if not guild_data:
@ -238,7 +236,7 @@ class Connection(PartialConnection):
guild_id = int(guild_data['id'])
guild = state._get_guild(guild_id)
if guild is None:
guild = Guild(data=guild_data, state=state)
guild = state.create_guild(guild_data)
return guild
async def edit(

5
discord/directory.py

@ -33,6 +33,7 @@ if TYPE_CHECKING:
from datetime import datetime
from .channel import DirectoryChannel
from .guild import Guild
from .member import Member
from .state import ConnectionState
from .types.directory import (
@ -118,8 +119,6 @@ class DirectoryEntry:
return NotImplemented
def _update(self, data: Union[DirectoryEntryPayload, PartialDirectoryEntryPayload]):
from .guild import Guild
state = self._state
self.type: DirectoryEntryType = try_enum(DirectoryEntryType, data['type'])
self.category: DirectoryCategory = try_enum(DirectoryCategory, data.get('primary_category_id', 0))
@ -129,7 +128,7 @@ class DirectoryEntry:
self.entity_id: int = int(data['entity_id'])
guild_data = data.get('guild', data.get('guild_scheduled_event', {}).get('guild'))
self.guild: Optional[Guild] = Guild(data=guild_data, state=state) if guild_data is not None else None
self.guild: Optional[Guild] = state.create_guild(guild_data) if guild_data is not None else None
self.featurable: bool = guild_data.get('featurable_in_directory', False) if guild_data is not None else False
event_data = data.get('guild_scheduled_event')

8
discord/emoji.py

@ -256,11 +256,13 @@ class Emoji(_EmojiTag, AssetMixin):
data = await self._state.http.edit_custom_emoji(self.guild_id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore # If guild is None, the http request would have failed
async def fetch_guild(self):
async def fetch_guild(self) -> Guild:
"""|coro|
Retrieves the guild this emoji belongs to.
.. versionadded:: 1.9
Raises
------
NotFound
@ -273,8 +275,6 @@ class Emoji(_EmojiTag, AssetMixin):
:class:`Guild`
The guild this emoji belongs to.
"""
from .guild import Guild # Circular import
state = self._state
data = await state.http.get_emoji_guild(self.id)
return Guild(state=state, data=data)
return state.create_guild(data)

6
discord/guild.py

@ -103,8 +103,8 @@ from .partial_emoji import _EmojiTag, PartialEmoji
if TYPE_CHECKING:
from .abc import Snowflake, SnowflakeTime
from .types.guild import (
BaseGuild as BaseGuildPayload,
Guild as GuildPayload,
PartialGuild as PartialGuildPayload,
RolePositionUpdate as RolePositionUpdatePayload,
UserGuild as UserGuildPayload,
)
@ -501,7 +501,7 @@ class Guild(Hashable):
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
}
def __init__(self, *, data: Union[GuildPayload, PartialGuildPayload], state: ConnectionState) -> None:
def __init__(self, *, data: Union[BaseGuildPayload, GuildPayload], state: ConnectionState) -> None:
self._chunked = False
self._cs_joined: Optional[bool] = None
self._roles: Dict[int, Role] = {}
@ -621,7 +621,7 @@ class Guild(Hashable):
def _create_unavailable(cls, *, state: ConnectionState, guild_id: int) -> Guild:
return cls(state=state, data={'id': guild_id, 'unavailable': True}) # type: ignore
def _from_data(self, guild: Union[GuildPayload, PartialGuildPayload]) -> None:
def _from_data(self, guild: Union[BaseGuildPayload, GuildPayload]) -> None:
try:
self._member_count: Optional[int] = guild['member_count'] # type: ignore # Handled below
except KeyError:

1
discord/invite.py

@ -62,7 +62,6 @@ if TYPE_CHECKING:
PartialChannel as InviteChannelPayload,
)
from .state import ConnectionState
from .guild import Guild
from .abc import GuildChannel
from .user import User

18
discord/partial_emoji.py

@ -24,11 +24,11 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
import re
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from .asset import Asset, AssetMixin
from . import utils
from .asset import Asset, AssetMixin
# fmt: off
__all__ = (
@ -37,12 +37,14 @@ __all__ = (
# fmt: on
if TYPE_CHECKING:
from datetime import datetime
from typing_extensions import Self
from .guild import Guild
from .state import ConnectionState
from datetime import datetime
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
class _EmojiTag:
@ -268,11 +270,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return await super().read()
async def fetch_guild(self):
async def fetch_guild(self) -> Guild:
"""|coro|
Retrieves the guild this emoji belongs to.
.. versionadded:: 1.9
Raises
------
NotFound
@ -289,8 +293,6 @@ class PartialEmoji(_EmojiTag, AssetMixin):
:class:`Guild`
The guild this emoji belongs to.
"""
from .guild import Guild # Circular import
if self.id is None:
raise ValueError('PartialEmoji is not a custom emoji')
if self._state is None:
@ -298,4 +300,4 @@ class PartialEmoji(_EmojiTag, AssetMixin):
state = self._state
data = await state.http.get_emoji_guild(self.id)
return Guild(state=state, data=data)
return state.create_guild(data)

15
discord/state.py

@ -121,7 +121,7 @@ if TYPE_CHECKING:
from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload
from .types.guild import BaseGuild as BaseGuildPayload, Guild as GuildPayload
from .types.message import (
Message as MessagePayload,
MessageSearchResult as MessageSearchResultPayload,
@ -840,17 +840,17 @@ class ConnectionState:
def guilds(self) -> Sequence[Guild]:
return utils.SequenceProxy(self._guilds.values())
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
def _get_guild(self, guild_id: Optional[int], /) -> Optional[Guild]:
# The keys of self._guilds are ints
return self._guilds.get(guild_id) # type: ignore
def _get_or_create_unavailable_guild(self, guild_id: int) -> Guild:
def _get_or_create_unavailable_guild(self, guild_id: int, /) -> Guild:
return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id)
def _add_guild(self, guild: Guild) -> None:
def _add_guild(self, guild: Guild, /) -> None:
self._guilds[guild.id] = guild
def _remove_guild(self, guild: Guild) -> None:
def _remove_guild(self, guild: Guild, /) -> None:
self._guilds.pop(guild.id, None)
for emoji in guild.emojis:
@ -861,6 +861,9 @@ class ConnectionState:
del guild
def create_guild(self, guild: BaseGuildPayload, /) -> Guild:
return Guild(data=guild, state=self)
@property
def emojis(self) -> Sequence[Emoji]:
return utils.SequenceProxy(self._emojis.values())
@ -923,7 +926,7 @@ class ConnectionState:
)
def _add_guild_from_data(self, data: GuildPayload) -> Guild:
guild = Guild(data=data, state=self)
guild = self.create_guild(data)
self._add_guild(guild)
return guild

14
discord/sticker.py

@ -229,11 +229,13 @@ class StickerItem(_StickerTag):
cls, _ = _sticker_factory(data['type'])
return cls(state=self._state, data=data)
async def fetch_guild(self):
async def fetch_guild(self) -> Guild:
"""|coro|
Retrieves the guild this sticker belongs to.
.. versionadded:: 1.9
Raises
------
NotFound
@ -246,11 +248,9 @@ class StickerItem(_StickerTag):
:class:`Guild`
The guild this emoji belongs to.
"""
from .guild import Guild # Circular import
state = self._state
data = await state.http.get_sticker_guild(self.id)
return Guild(state=state, data=data)
return state.create_guild(data)
class Sticker(_StickerTag):
@ -532,7 +532,7 @@ class GuildSticker(Sticker):
"""
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
async def fetch_guild(self):
async def fetch_guild(self) -> Guild:
"""|coro|
Retrieves the guild this sticker belongs to.
@ -549,11 +549,9 @@ class GuildSticker(Sticker):
:class:`Guild`
The guild this emoji belongs to.
"""
from .guild import Guild # Circular import
state = self._state
data = await state.http.get_sticker_guild(self.id)
return Guild(state=state, data=data)
return state.create_guild(data)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:

4
discord/store.py

@ -679,8 +679,6 @@ class StoreListing(Hashable):
return f'<StoreListing id={self.id} summary={self.summary!r} sku={self.sku!r}>'
def _update(self, data: StoreListingPayload, application: Optional[PartialApplication] = None) -> None:
from .guild import Guild
state = self._state
self.summary, self.summary_localizations = _parse_localizations(data, 'summary')
@ -693,7 +691,7 @@ class StoreListing(Hashable):
self.child_skus: List[SKU] = [SKU(data=sku, state=state) for sku in data.get('child_skus', [])]
self.alternative_skus: List[SKU] = [SKU(data=sku, state=state) for sku in data.get('alternative_skus', [])]
self.entitlement_branch_id: Optional[int] = _get_as_snowflake(data, 'entitlement_branch_id')
self.guild: Optional[Guild] = Guild(data=data['guild'], state=state) if 'guild' in data else None
self.guild: Optional[Guild] = state.create_guild(data['guild']) if 'guild' in data else None
self.published: bool = data.get('published', True)
self.staff_note: Optional[StoreNote] = (
StoreNote(data=data['staff_notes'], state=state) if 'staff_notes' in data else None

Loading…
Cancel
Save