From 61e340981f95435a02aa6a7d7bca06dc661b039c Mon Sep 17 00:00:00 2001 From: dolfies Date: Sat, 27 May 2023 12:45:03 -0400 Subject: [PATCH] Implement guest invites --- discord/client.py | 2 +- discord/flags.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++ discord/guild.py | 2 +- discord/http.py | 9 +++++- discord/invite.py | 57 ++++++++++++++++++++++---------------- docs/api.rst | 5 ++++ 6 files changed, 118 insertions(+), 27 deletions(-) diff --git a/discord/client.py b/discord/client.py index ea0c061b3..22745f6a9 100644 --- a/discord/client.py +++ b/discord/client.py @@ -2052,7 +2052,7 @@ class Client: 'channel_id': getattr(invite.channel, 'id', MISSING), 'channel_type': getattr(invite.channel, 'type', MISSING), } - data = await state.http.accept_invite(invite.code, type, **kwargs) + data = await state.http.accept_invite(invite.code, type, state.session_id or utils._generate_session_id(), **kwargs) return Invite.from_incomplete(state=state, data=data, message=invite._message) async def delete_invite(self, invite: Union[Invite, str], /) -> Invite: diff --git a/discord/flags.py b/discord/flags.py index 0ea5962d4..4e402ffb7 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -58,6 +58,7 @@ __all__ = ( 'AutoModPresets', 'MemberFlags', 'ReadStateFlags', + 'InviteFlags', ) BF = TypeVar('BF', bound='BaseFlags') @@ -2398,6 +2399,15 @@ class MemberFlags(BaseFlags): """:class:`bool`: Returns ``True`` if the member has started onboarding.""" return 1 << 3 + @flag_value + def guest(self): + """:class:`bool`: Returns ``True`` if the member is a guest. + Guest members are members that joined through a guest invite, and are not full members of the guild. + + .. versionadded:: 2.1 + """ + return 1 << 4 + @fill_with_flags() class ReadStateFlags(BaseFlags): @@ -2462,3 +2472,63 @@ class ReadStateFlags(BaseFlags): def thread(self): """:class:`bool`: Returns ``True`` if the read state is for a thread.""" return 1 << 1 + + +@fill_with_flags() +class InviteFlags(BaseFlags): + r"""Wraps up the Discord invite flags. + + .. container:: operations + + .. describe:: x == y + + Checks if two InviteFlags are equal. + + .. describe:: x != y + + Checks if two InviteFlags are not equal. + + .. describe:: x | y, x |= y + + Returns a InviteFlags instance with all enabled flags from + both x and y. + + .. describe:: x & y, x &= y + + Returns a InviteFlags instance with only flags enabled on + both x and y. + + .. describe:: x ^ y, x ^= y + + Returns a InviteFlags instance with only flags enabled on + only one of x or y, not on both. + + .. describe:: ~x + + Returns a InviteFlags instance with all flags inverted from x. + + .. describe:: hash(x) + + Return the flag's hash. + + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + Note that aliases are not shown. + + .. versionadded:: 2.1 + + Attributes + ----------- + value: :class:`int` + The raw value. You should query flags via the properties + rather than using this raw value. + """ + + __slots__ = () + + @flag_value + def guest(self): + """:class:`bool`: Returns ``True`` if the invite is a guest invite. Guest invites grant temporary membership for the purposes of joining a voice channel.""" + return 1 << 0 diff --git a/discord/guild.py b/discord/guild.py index 39ddaaf5d..2f60e216d 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -768,7 +768,7 @@ class Guild(Hashable): return self.get_member(self_id) # type: ignore def is_joined(self) -> bool: - """Returns whether you are a member of this guild. + """Returns whether you are a full member of this guild. May not be accurate for :class:`Guild` s fetched over HTTP. diff --git a/discord/http.py b/discord/http.py index 1a43a8ad3..dcba9b333 100644 --- a/discord/http.py +++ b/discord/http.py @@ -2238,6 +2238,7 @@ class HTTPClient: self, invite_id: str, type: InviteType, + session_id: Optional[str] = None, *, guild_id: Snowflake = MISSING, channel_id: Snowflake = MISSING, @@ -2262,7 +2263,13 @@ class HTTPClient: props = ContextProperties.from_accept_invite_page( guild_id=guild_id, channel_id=channel_id, channel_type=channel_type ) - return self.request(Route('POST', '/invites/{invite_id}', invite_id=invite_id), context_properties=props, json={}) + payload = {} + if session_id is not None: + payload['session_id'] = session_id + + return self.request( + Route('POST', '/invites/{invite_id}', invite_id=invite_id), context_properties=props, json=payload + ) def create_invite( self, diff --git a/discord/invite.py b/discord/invite.py index 65f83f915..e582e4a3c 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -24,13 +24,15 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import List, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional, Union + from .asset import Asset -from .utils import parse_time, snowflake_time, _get_as_snowflake, MISSING -from .object import Object +from .enums import ChannelType, InviteTarget, InviteType, NSFWLevel, VerificationLevel, try_enum +from .flags import InviteFlags from .mixins import Hashable +from .object import Object from .scheduled_event import ScheduledEvent -from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, NSFWLevel, try_enum +from .utils import MISSING, _generate_session_id, _get_as_snowflake, parse_time, snowflake_time from .welcome_screen import WelcomeScreen __all__ = ( @@ -40,29 +42,27 @@ __all__ = ( ) if TYPE_CHECKING: + import datetime + from typing_extensions import Self + from .abc import GuildChannel, Snowflake + from .application import PartialApplication + from .channel import DMChannel, GroupChannel + from .guild import Guild + from .message import Message + from .state import ConnectionState + from .types.channel import PartialChannel as InviteChannelPayload from .types.invite import ( + GatewayInvite as GatewayInvitePayload, Invite as InvitePayload, InviteGuild as InviteGuildPayload, - GatewayInvite as GatewayInvitePayload, - ) - from .types.channel import ( - PartialChannel as InviteChannelPayload, ) - from .state import ConnectionState - from .guild import Guild - from .abc import GuildChannel, Snowflake - from .channel import DMChannel, GroupChannel from .user import User - from .application import PartialApplication - from .message import Message InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object, DMChannel, GroupChannel] - import datetime - class PartialInviteChannel: """Represents a "partial" invite channel. @@ -97,7 +97,8 @@ class PartialInviteChannel: type: :class:`ChannelType` The partial channel's type. recipients: Optional[List[:class:`str`]] - The partial channel's recipient names. This is only applicable to group DMs. + The partial channel's recipient names. + This is applicable to channels of type :attr:`ChannelType.group`. .. versionadded:: 2.0 """ @@ -401,7 +402,7 @@ class Invite(Hashable): This is only possibly ``True`` in accepted invite objects (i.e. the objects received from :meth:`accept` and :meth:`use`). show_verification_form: :class:`bool` - Whether the user should be shown the guild's member verification form. + Whether the user should be shown the guild's membership screening form. .. versionadded:: 2.0 @@ -435,6 +436,7 @@ class Invite(Hashable): 'type', 'new_member', 'show_verification_form', + '_flags', ) BASE = 'https://discord.gg' @@ -447,6 +449,7 @@ class Invite(Hashable): guild: Optional[Union[PartialInviteGuild, Guild]] = None, channel: Optional[Union[PartialInviteChannel, GuildChannel, GroupChannel]] = None, welcome_screen: Optional[WelcomeScreen] = None, + message: Optional[Message] = None, ): self._state: ConnectionState = state self.type: InviteType = try_enum(InviteType, data.get('type', 0)) @@ -460,7 +463,8 @@ class Invite(Hashable): self.max_uses: Optional[int] = data.get('max_uses') self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') self.approximate_member_count: Optional[int] = data.get('approximate_member_count') - self._message: Optional[Message] = data.get('message') + self._flags: int = data.get('flags', 0) + self._message: Optional[Message] = message # We inject some missing data here since we can assume it if self.type in (InviteType.group_dm, InviteType.friend): @@ -530,10 +534,7 @@ class Invite(Hashable): channel = PartialInviteChannel(channel_data, state) channel = state.get_channel(getattr(channel, 'id', None)) or channel - if message is not None: - data['message'] = message # type: ignore # Not a real field - - return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) # type: ignore + return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen, message=message) # type: ignore @classmethod def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self: @@ -601,6 +602,14 @@ class Invite(Hashable): url += '?event=' + str(self.scheduled_event_id) return url + @property + def flags(self) -> InviteFlags: + """:class:`InviteFlags`: Returns the invite's flags. + + .. versionadded:: 2.1 + """ + return InviteFlags._from_value(self._flags) + def set_scheduled_event(self, scheduled_event: Snowflake, /) -> Self: """Sets the scheduled event for this invite. @@ -654,7 +663,7 @@ class Invite(Hashable): 'channel_id': getattr(self.channel, 'id', MISSING), 'channel_type': getattr(self.channel, 'type', MISSING), } - data = await state.http.accept_invite(self.code, type, **kwargs) + data = await state.http.accept_invite(self.code, type, state.session_id or _generate_session_id(), **kwargs) return Invite.from_incomplete(state=state, data=data, message=message) async def accept(self) -> Invite: diff --git a/docs/api.rst b/docs/api.rst index 96a60361b..209fca733 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -7779,6 +7779,11 @@ Flags .. autoclass:: HubProgressFlags() :members: +.. attributetable:: InviteFlags + +.. autoclass:: InviteFlags() + :members: + .. attributetable:: MemberFlags .. autoclass:: MemberFlags()