diff --git a/discord/abc.py b/discord/abc.py index da1636f27..2b44fa306 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1106,14 +1106,8 @@ class GuildChannel: state = self._state data = await state.http.invites_from_channel(self.id) - result = [] - - for invite in data: - invite['channel'] = self - invite['guild'] = self.guild - result.append(Invite(state=state, data=invite)) - - return result + guild = self.guild + return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] class Messageable(Protocol): diff --git a/discord/audit_logs.py b/discord/audit_logs.py index d05c8a672..f6a931509 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -441,12 +441,10 @@ class AuditLogEntry(Hashable): 'max_uses': changeset.max_uses, 'code': changeset.code, 'temporary': changeset.temporary, - 'channel': changeset.channel, 'uses': changeset.uses, - 'guild': self.guild, } - obj = Invite(state=self._state, data=fake_payload) + obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) try: obj.inviter = changeset.inviter except AttributeError: diff --git a/discord/guild.py b/discord/guild.py index a08aafb53..e7c74de92 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -1718,9 +1718,7 @@ class Guild(Hashable): result = [] for invite in data: channel = self.get_channel(int(invite['channel']['id'])) - invite['channel'] = channel - invite['guild'] = self - result.append(Invite(state=self._state, data=invite)) + result.append(Invite(state=self._state, data=invite, guild=self, channel=channel)) return result @@ -2219,13 +2217,12 @@ class Guild(Hashable): # reliable or a thing anymore data = await self._state.http.get_invite(payload['code']) - payload['guild'] = self - payload['channel'] = self.get_channel(int(data['channel']['id'])) + channel = self.get_channel(int(data['channel']['id'])) payload['revoked'] = False payload['temporary'] = False payload['max_uses'] = 0 payload['max_age'] = 0 - return Invite(state=self._state, data=payload) + return Invite(state=self._state, data=payload, guild=self, channel=channel) def audit_logs( self, diff --git a/discord/invite.py b/discord/invite.py index 8e60d1e8b..7c0ac6b49 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import List, Optional, Type, TypeVar, Union, TYPE_CHECKING from .asset import Asset from .utils import parse_time, snowflake_time, _get_as_snowflake from .object import Object @@ -42,10 +42,19 @@ if TYPE_CHECKING: from .types.invite import ( Invite as InvitePayload, InviteGuild as InviteGuildPayload, + GatewayInvite as GatewayInvitePayload, ) from .types.channel import ( - PartialChannel as PartialChannelPayload, + PartialChannel as InviteChannelPayload, ) + from .state import ConnectionState + from .guild import Guild + from .abc import GuildChannel + from .user import User + + InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] + InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object] + import datetime @@ -85,24 +94,24 @@ class PartialInviteChannel: __slots__ = ('id', 'name', 'type') - def __init__(self, **kwargs): - self.id = int(kwargs.pop('id')) - self.name = kwargs.pop('name') - self.type = kwargs.pop('type') + def __init__(self, data: InviteChannelPayload): + self.id: int = int(data['id']) + self.name: str = data['name'] + self.type: ChannelType = try_enum(ChannelType, data['type']) - def __str__(self): + def __str__(self) -> str: return self.name - def __repr__(self): + def __repr__(self) -> str: return f'' @property - def mention(self): + def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" return f'<#{self.id}>' @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" return snowflake_time(self.id) @@ -147,16 +156,16 @@ class PartialInviteGuild: __slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description') - def __init__(self, state, data: InviteGuildPayload, id: int): - self._state = state - self.id = id - self.name = data['name'] - self.features = data.get('features', []) - self._icon = data.get('icon') - self._banner = data.get('banner') - self._splash = data.get('splash') - self.verification_level = try_enum(VerificationLevel, data.get('verification_level')) - self.description = data.get('description') + def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): + self._state: ConnectionState = state + self.id: int = id + self.name: str = data['name'] + self.features: List[str] = data.get('features', []) + self._icon: Optional[str] = data.get('icon') + self._banner: Optional[str] = data.get('banner') + self._splash: Optional[str] = data.get('splash') + self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level')) + self.description: Optional[str] = data.get('description') def __str__(self) -> str: return self.name @@ -194,6 +203,9 @@ class PartialInviteGuild: return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') +I = TypeVar('I', bound='Invite') + + class Invite(Hashable): r"""Represents a Discord :class:`Guild` or :class:`abc.GuildChannel` invite. @@ -280,9 +292,9 @@ class Invite(Hashable): The channel the invite is for. target_type: :class:`InviteTarget` The type of target for the voice channel invite. - + .. versionadded:: 2.0 - + target_user: Optional[:class:`User`] The user whose stream to display for this invite, if any. @@ -316,74 +328,107 @@ class Invite(Hashable): BASE = 'https://discord.gg' - def __init__(self, *, state, data: InvitePayload): - self._state = state - self.max_age = data.get('max_age') - self.code = data['code'] - self.guild = data.get('guild') - self.revoked = data.get('revoked') - self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) # type: ignore - self.temporary = data.get('temporary') - self.uses = data.get('uses') - self.max_uses = data.get('max_uses') - self.approximate_presence_count = data.get('approximate_presence_count') - self.approximate_member_count = data.get('approximate_member_count') + def __init__( + self, + *, + state: ConnectionState, + data: InvitePayload, + guild: Optional[Union[PartialInviteGuild, Guild]] = None, + channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, + ): + self._state: ConnectionState = state + self.max_age: Optional[int] = data.get('max_age') + self.code: str = data['code'] + self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) + self.revoked: Optional[bool] = data.get('revoked') + self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) + self.temporary: Optional[bool] = data.get('temporary') + self.uses: Optional[int] = data.get('uses') + self.max_uses: Optional[int] = data.get('max_uses') + self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') + self.approximate_member_count: Optional[int] = data.get('approximate_member_count') + expires_at = data.get('expires_at', None) - self.expires_at = parse_time(expires_at) if expires_at else None + self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None inviter_data = data.get('inviter') - self.inviter = None if inviter_data is None else self._state.store_user(inviter_data) - self.channel = data.get('channel') + self.inviter: Optional[User] = None if inviter_data is None else self._state.store_user(inviter_data) + + self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel) + target_user_data = data.get('target_user') - self.target_user = None if target_user_data is None else self._state.store_user(target_user_data) + self.target_user: Optional[User] = None if target_user_data is None else self._state.store_user(target_user_data) - self.target_type = try_enum(InviteTarget, data.get("target_type", 0)) + self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) application = data.get('target_application') - self.target_application = PartialAppInfo(data=application, state=state) if application else None + self.target_application: Optional[PartialAppInfo] = ( + PartialAppInfo(data=application, state=state) if application else None + ) @classmethod - def from_incomplete(cls, *, state, data): + def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I: + guild: Optional[Union[Guild, PartialInviteGuild]] try: - guild_id = int(data['guild']['id']) + guild_data = data['guild'] except KeyError: # If we're here, then this is a group DM guild = None else: + guild_id = int(guild_data['id']) guild = state._get_guild(guild_id) if guild is None: # If it's not cached, then it has to be a partial guild - guild_data = data['guild'] guild = PartialInviteGuild(state, guild_data, guild_id) # As far as I know, invites always need a channel # So this should never raise. - channel_data: PartialChannelPayload = data['channel'] - channel_id = int(channel_data['id']) - channel_type = try_enum(ChannelType, channel_data['type']) - channel = PartialInviteChannel(id=channel_id, name=channel_data['name'], type=channel_type) + channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) if guild is not None and not isinstance(guild, PartialInviteGuild): # Upgrade the partial data if applicable - channel = guild.get_channel(channel_id) or channel + channel = guild.get_channel(channel.id) or channel - data['guild'] = guild - data['channel'] = channel - return cls(state=state, data=data) + return cls(state=state, data=data, guild=guild, channel=channel) @classmethod - def from_gateway(cls, *, state, data): - guild_id = _get_as_snowflake(data, 'guild_id') - guild = state._get_guild(guild_id) - channel_id = _get_as_snowflake(data, 'channel_id') + def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: + guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') + guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) + channel_id = int(data['channel_id']) if guild is not None: - channel = guild.get_channel(channel_id) or Object(id=channel_id) + channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore else: - guild = Object(id=guild_id) + guild = Object(id=guild_id) if guild_id is not None else None channel = Object(id=channel_id) - data['guild'] = guild - data['channel'] = channel - return cls(state=state, data=data) + return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore + + def _resolve_guild( + self, + data: Optional[InviteGuildPayload], + guild: Optional[Union[Guild, PartialInviteGuild]] = None, + ) -> Optional[InviteGuildType]: + if guild is not None: + return guild + + if data is None: + return None + + guild_id = int(data['id']) + return PartialInviteGuild(self._state, data, guild_id) + + def _resolve_channel( + self, + data: Optional[InviteChannelPayload], + channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, + ) -> Optional[InviteChannelType]: + if channel is not None: + return channel + + if data is None: + return None + + return PartialInviteChannel(data) def __str__(self) -> str: return self.url diff --git a/discord/types/invite.py b/discord/types/invite.py index 3760f0fab..faf4b73aa 100644 --- a/discord/types/invite.py +++ b/discord/types/invite.py @@ -24,8 +24,9 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Literal, TypedDict +from typing import Literal, Optional, TypedDict, Union +from .snowflake import Snowflake from .guild import InviteGuild, _GuildPreviewUnique from .channel import PartialChannel from .user import PartialUser @@ -45,6 +46,7 @@ class _InviteOptional(TypedDict, total=False): class _InviteMetadata(TypedDict, total=False): uses: int max_uses: int + max_age: int temporary: bool created_at: str @@ -60,3 +62,33 @@ class Invite(IncompleteInvite, _InviteOptional): class InviteWithCounts(Invite, _GuildPreviewUnique): ... + + +class _GatewayInviteCreateOptional(TypedDict, total=False): + guild_id: Snowflake + inviter: PartialUser + target_type: InviteTargetType + target_user: PartialUser + target_application: PartialAppInfo + + +class GatewayInviteCreate(_GatewayInviteCreateOptional): + channel_id: Snowflake + code: str + created_at: str + max_age: int + max_uses: int + temporary: bool + uses: bool + + +class _GatewayInviteDeleteOptional(TypedDict, total=False): + guild_id: Snowflake + + +class GatewayInviteDelete(_GatewayInviteDeleteOptional): + channel_id: Snowflake + code: str + + +GatewayInvite = Union[GatewayInviteCreate, GatewayInviteDelete]