From 0e0ff384f7dca5e325fec10a463b9f7e4f6c4758 Mon Sep 17 00:00:00 2001 From: dolfies Date: Sat, 25 Dec 2021 13:13:43 -0500 Subject: [PATCH] Fix group/friend invites --- discord/invite.py | 153 +++++++++++++++++++++++++++++++++------------- 1 file changed, 109 insertions(+), 44 deletions(-) diff --git a/discord/invite.py b/discord/invite.py index 17c0d97e3..12c6290bc 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -26,11 +26,11 @@ from __future__ import annotations from typing import List, Optional, Type, TypeVar, Union, TYPE_CHECKING from .asset import Asset -from .utils import parse_time, snowflake_time, _get_as_snowflake +from .utils import parse_time, snowflake_time, _get_as_snowflake, MISSING from .object import Object from .mixins import Hashable -from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum -from .appinfo import PartialAppInfo +from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, try_enum +from .welcome_screen import WelcomeScreen __all__ = ( 'PartialInviteChannel', @@ -49,11 +49,14 @@ if TYPE_CHECKING: ) from .state import ConnectionState from .guild import Guild - from .abc import GuildChannel + from .abc import GuildChannel, PrivateChannel from .user import User + from .appinfo import PartialApplication + from .message import Message + from .channel import GroupChannel InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] - InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object] + InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object, PrivateChannel] import datetime @@ -94,7 +97,14 @@ class PartialInviteChannel: __slots__ = ('id', 'name', 'type') - def __init__(self, data: InviteChannelPayload): + def __new__(cls, data: Optional[InviteChannelPayload]): + if data is None: + return + return super().__new__(cls) + + def __init__(self, data: Optional[InviteChannelPayload]): + if data is None: + return self.id: int = int(data['id']) self.name: str = data['name'] self.type: ChannelType = try_enum(ChannelType, data['type']) @@ -261,8 +271,13 @@ class Invite(Hashable): A value of ``0`` indicates that it doesn't expire. code: :class:`str` The URL fragment used for the invite. + type: :class:`InviteType` + The type of invite. + + .. versionadded:: 2.0 + guild: Optional[Union[:class:`Guild`, :class:`Object`, :class:`PartialInviteGuild`]] - The guild the invite is for. Can be ``None`` if it's from a group direct message. + The guild the invite is for. Can be ``None`` if not a guild invite. revoked: :class:`bool` Indicates if the invite has been revoked. created_at: :class:`datetime.datetime` @@ -288,8 +303,8 @@ class Invite(Hashable): .. versionadded:: 2.0 - channel: Union[:class:`abc.GuildChannel`, :class:`Object`, :class:`PartialInviteChannel`] - The channel the invite is for. + channel: Optional[Union[:class:`abc.GuildChannel`, :class:`Object`, :class:`PartialInviteChannel`]] + The channel the invite is for. Can be ``None`` if not a guild invite. target_type: :class:`InviteTarget` The type of target for the voice channel invite. @@ -300,9 +315,14 @@ class Invite(Hashable): .. versionadded:: 2.0 - target_application: Optional[:class:`PartialAppInfo`] + target_application: Optional[:class:`PartialApplication`] The embedded application the invite targets, if any. + .. versionadded:: 2.0 + + welcome_screen: Optional[:class:`WelcomeScreen`] + The guild's welcome screen, if available. + .. versionadded:: 2.0 """ @@ -324,7 +344,9 @@ class Invite(Hashable): 'approximate_presence_count', 'target_application', 'expires_at', - '_message_id', + '_message', + 'welcome_screen', + 'type', ) BASE = 'https://discord.gg' @@ -336,8 +358,10 @@ class Invite(Hashable): data: InvitePayload, guild: Optional[Union[PartialInviteGuild, Guild]] = None, channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, + welcome_screen: Optional[WelcomeScreen] = None, ): self._state: ConnectionState = state + self.type: InviteType = try_enum(InviteType, data.get('type', 0)) self.max_age: Optional[int] = data.get('max_age') self.code: str = data['code'] self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) @@ -348,7 +372,7 @@ class Invite(Hashable): self.max_uses: Optional[int] = data.get('max_uses') self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') self.approximate_member_count: Optional[int] = data.get('approximate_member_count') - self._message_id: Optional[int] = data.get('message_id') + self._message: Optional[Message] = data.get('message') expires_at = data.get('expires_at', None) self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None @@ -364,13 +388,16 @@ class Invite(Hashable): self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) application = data.get('target_application') - self.target_application: Optional[PartialAppInfo] = ( - PartialAppInfo(data=application, state=state) if application else None - ) + if application is not None: + from .appinfo import PartialApplication + application = PartialApplication(data=application, state=state) + self.target_application: Optional[PartialApplication] = application + + self.welcome_screen = welcome_screen @classmethod def from_incomplete( - cls: Type[I], *, state: ConnectionState, data: InvitePayload, message_id: Optional[int] = None + cls: Type[I], *, state: ConnectionState, data: InvitePayload, message: Optional[Message] = None ) -> I: guild: Optional[Union[Guild, PartialInviteGuild]] try: @@ -378,33 +405,34 @@ class Invite(Hashable): except KeyError: # If we're here, then this is a group DM guild = None + welcome_screen = None else: guild_id = int(guild_data['id']) guild = state._get_guild(guild_id) if guild is None: guild = PartialInviteGuild(state, guild_data, guild_id) - # As far as I know, invites always need a channel - channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) - if guild is not None and not isinstance(guild, PartialInviteGuild): - # Upgrade the partial data if applicable - channel = guild.get_channel(channel.id) or channel + welcome_screen = guild_data.get('welcome_screen') + if welcome_screen is not None: + welcome_screen = WelcomeScreen(data=welcome_screen, guild=guild) - if message_id is not None: - data['message_id'] = message_id + channel = PartialInviteChannel(data.get('channel')) + channel = state.get_channel(getattr(channel, 'id', None)) or channel - return cls(state=state, data=data, guild=guild, channel=channel) + if message is not None: + data['message'] = message + + return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) @classmethod def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') - guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) - channel_id = int(data['channel_id']) - if guild is not None: - channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore - else: - guild = Object(id=guild_id) if guild_id is not None else None - channel = Object(id=channel_id) + + channel_id = _get_as_snowflake(data, 'channel_id') + if guild_id is not None: + guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) or Object(id=guild_id) + if channel_id is not None: + channel: Optional[InviteChannelType] = state.get_channel(channel_id) or Object(id=channel_id) # type: ignore return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore @@ -440,8 +468,8 @@ class Invite(Hashable): def __repr__(self) -> str: return ( - f'' ) @@ -458,30 +486,68 @@ class Invite(Hashable): """:class:`str`: A property that retrieves the invite URL.""" return self.BASE + '/' + self.code - async def use(self) -> Guild: + async def use(self) -> Union[Guild, User, GroupChannel]: """|coro| - Uses the invite (joins the guild) + Uses the invite. + Either joins a guild, joins a group DM, or adds a friend. + + There is an alias for this called :func:`accept`. .. versionadded:: 1.9 Raises ------ :exc:`.HTTPException` - Joining the guild failed. + Using the invite failed. Returns ------- - :class:`.Guild` - The guild joined. This is not the same guild that is - added to cache. + Union[:class:`Guild`, :class:`User`, :class:`GroupChannel`] + The guild/group DM joined, or user added as a friend. """ - state = self._state - data = await state.http.join_guild(self.code, self.guild.id, self.channel.id, self.channel.type.value, self._message_id) - return state.Guild(data=data['guild'], state=state) # Circular import + type = self.type + if (message := self._message): + kwargs = {'message': message} + else: + kwargs = { + 'guild_id': getattr(self.guild, 'id', MISSING), + 'channel_id': getattr(self.channel, 'id', MISSING), + 'channel_type': getattr(self.channel, 'type', MISSING), + } + data = await state.http.accept_invite(self.code, type, **kwargs) + if type is InviteType.guild: + from .guild import Guild + return Guild(data=data['guild'], state=state) + elif type is InviteType.group_dm: + from .channel import GroupChannel + return GroupChannel(data=data['channel'], state=state, me=state.user) # type: ignore + else: + from .user import User + return User(data=data['inviter'], state=state) + + async def accept(self) -> Union[Guild, User, GroupChannel]: + """|coro| + + Uses the invite. + Either joins a guild, joins a group DM, or adds a friend. + + This is an alias of :func:`use`. + + .. versionadded:: 1.9 - accept = use + Raises + ------ + :exc:`.HTTPException` + Using the invite failed. + + Returns + ------- + Union[:class:`Guild`, :class:`User`, :class:`GroupChannel`] + The guild/group DM joined, or user added as a friend. + """ + return await self.use() async def delete(self, *, reason: Optional[str] = None): """|coro| @@ -504,5 +570,4 @@ class Invite(Hashable): HTTPException Revoking the invite failed. """ - await self._state.http.delete_invite(self.code, reason=reason)