Browse Source

Fix group/friend invites

pull/10109/head
dolfies 4 years ago
parent
commit
0e0ff384f7
  1. 153
      discord/invite.py

153
discord/invite.py

@ -26,11 +26,11 @@ from __future__ import annotations
from typing import List, Optional, Type, TypeVar, Union, TYPE_CHECKING from typing import List, Optional, Type, TypeVar, Union, TYPE_CHECKING
from .asset import Asset from .asset import Asset
from .utils import parse_time, snowflake_time, _get_as_snowflake from .utils import parse_time, snowflake_time, _get_as_snowflake, MISSING
from .object import Object from .object import Object
from .mixins import Hashable from .mixins import Hashable
from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, try_enum
from .appinfo import PartialAppInfo from .welcome_screen import WelcomeScreen
__all__ = ( __all__ = (
'PartialInviteChannel', 'PartialInviteChannel',
@ -49,11 +49,14 @@ if TYPE_CHECKING:
) )
from .state import ConnectionState from .state import ConnectionState
from .guild import Guild from .guild import Guild
from .abc import GuildChannel from .abc import GuildChannel, PrivateChannel
from .user import User from .user import User
from .appinfo import PartialApplication
from .message import Message
from .channel import GroupChannel
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] InviteGuildType = Union[Guild, 'PartialInviteGuild', Object]
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object] InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object, PrivateChannel]
import datetime import datetime
@ -94,7 +97,14 @@ class PartialInviteChannel:
__slots__ = ('id', 'name', 'type') __slots__ = ('id', 'name', 'type')
def __init__(self, data: InviteChannelPayload): def __new__(cls, data: Optional[InviteChannelPayload]):
if data is None:
return
return super().__new__(cls)
def __init__(self, data: Optional[InviteChannelPayload]):
if data is None:
return
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.name: str = data['name'] self.name: str = data['name']
self.type: ChannelType = try_enum(ChannelType, data['type']) self.type: ChannelType = try_enum(ChannelType, data['type'])
@ -261,8 +271,13 @@ class Invite(Hashable):
A value of ``0`` indicates that it doesn't expire. A value of ``0`` indicates that it doesn't expire.
code: :class:`str` code: :class:`str`
The URL fragment used for the invite. The URL fragment used for the invite.
type: :class:`InviteType`
The type of invite.
.. versionadded:: 2.0
guild: Optional[Union[:class:`Guild`, :class:`Object`, :class:`PartialInviteGuild`]] guild: Optional[Union[:class:`Guild`, :class:`Object`, :class:`PartialInviteGuild`]]
The guild the invite is for. Can be ``None`` if it's from a group direct message. The guild the invite is for. Can be ``None`` if not a guild invite.
revoked: :class:`bool` revoked: :class:`bool`
Indicates if the invite has been revoked. Indicates if the invite has been revoked.
created_at: :class:`datetime.datetime` created_at: :class:`datetime.datetime`
@ -288,8 +303,8 @@ class Invite(Hashable):
.. versionadded:: 2.0 .. versionadded:: 2.0
channel: Union[:class:`abc.GuildChannel`, :class:`Object`, :class:`PartialInviteChannel`] channel: Optional[Union[:class:`abc.GuildChannel`, :class:`Object`, :class:`PartialInviteChannel`]]
The channel the invite is for. The channel the invite is for. Can be ``None`` if not a guild invite.
target_type: :class:`InviteTarget` target_type: :class:`InviteTarget`
The type of target for the voice channel invite. The type of target for the voice channel invite.
@ -300,9 +315,14 @@ class Invite(Hashable):
.. versionadded:: 2.0 .. versionadded:: 2.0
target_application: Optional[:class:`PartialAppInfo`] target_application: Optional[:class:`PartialApplication`]
The embedded application the invite targets, if any. The embedded application the invite targets, if any.
.. versionadded:: 2.0
welcome_screen: Optional[:class:`WelcomeScreen`]
The guild's welcome screen, if available.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
@ -324,7 +344,9 @@ class Invite(Hashable):
'approximate_presence_count', 'approximate_presence_count',
'target_application', 'target_application',
'expires_at', 'expires_at',
'_message_id', '_message',
'welcome_screen',
'type',
) )
BASE = 'https://discord.gg' BASE = 'https://discord.gg'
@ -336,8 +358,10 @@ class Invite(Hashable):
data: InvitePayload, data: InvitePayload,
guild: Optional[Union[PartialInviteGuild, Guild]] = None, guild: Optional[Union[PartialInviteGuild, Guild]] = None,
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
welcome_screen: Optional[WelcomeScreen] = None,
): ):
self._state: ConnectionState = state self._state: ConnectionState = state
self.type: InviteType = try_enum(InviteType, data.get('type', 0))
self.max_age: Optional[int] = data.get('max_age') self.max_age: Optional[int] = data.get('max_age')
self.code: str = data['code'] self.code: str = data['code']
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild)
@ -348,7 +372,7 @@ class Invite(Hashable):
self.max_uses: Optional[int] = data.get('max_uses') self.max_uses: Optional[int] = data.get('max_uses')
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count')
self.approximate_member_count: Optional[int] = data.get('approximate_member_count') self.approximate_member_count: Optional[int] = data.get('approximate_member_count')
self._message_id: Optional[int] = data.get('message_id') self._message: Optional[Message] = data.get('message')
expires_at = data.get('expires_at', None) expires_at = data.get('expires_at', None)
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
@ -364,13 +388,16 @@ class Invite(Hashable):
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
application = data.get('target_application') application = data.get('target_application')
self.target_application: Optional[PartialAppInfo] = ( if application is not None:
PartialAppInfo(data=application, state=state) if application else None from .appinfo import PartialApplication
) application = PartialApplication(data=application, state=state)
self.target_application: Optional[PartialApplication] = application
self.welcome_screen = welcome_screen
@classmethod @classmethod
def from_incomplete( def from_incomplete(
cls: Type[I], *, state: ConnectionState, data: InvitePayload, message_id: Optional[int] = None cls: Type[I], *, state: ConnectionState, data: InvitePayload, message: Optional[Message] = None
) -> I: ) -> I:
guild: Optional[Union[Guild, PartialInviteGuild]] guild: Optional[Union[Guild, PartialInviteGuild]]
try: try:
@ -378,33 +405,34 @@ class Invite(Hashable):
except KeyError: except KeyError:
# If we're here, then this is a group DM # If we're here, then this is a group DM
guild = None guild = None
welcome_screen = None
else: else:
guild_id = int(guild_data['id']) guild_id = int(guild_data['id'])
guild = state._get_guild(guild_id) guild = state._get_guild(guild_id)
if guild is None: if guild is None:
guild = PartialInviteGuild(state, guild_data, guild_id) guild = PartialInviteGuild(state, guild_data, guild_id)
# As far as I know, invites always need a channel welcome_screen = guild_data.get('welcome_screen')
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) if welcome_screen is not None:
if guild is not None and not isinstance(guild, PartialInviteGuild): welcome_screen = WelcomeScreen(data=welcome_screen, guild=guild)
# Upgrade the partial data if applicable
channel = guild.get_channel(channel.id) or channel
if message_id is not None: channel = PartialInviteChannel(data.get('channel'))
data['message_id'] = message_id channel = state.get_channel(getattr(channel, 'id', None)) or channel
return cls(state=state, data=data, guild=guild, channel=channel) if message is not None:
data['message'] = message
return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen)
@classmethod @classmethod
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I:
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data['channel_id']) channel_id = _get_as_snowflake(data, 'channel_id')
if guild is not None: if guild_id is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) or Object(id=guild_id)
else: if channel_id is not None:
guild = Object(id=guild_id) if guild_id is not None else None channel: Optional[InviteChannelType] = state.get_channel(channel_id) or Object(id=channel_id) # type: ignore
channel = Object(id=channel_id)
return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore
@ -440,8 +468,8 @@ class Invite(Hashable):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Invite code={self.code!r} guild={self.guild!r} ' f'<Invite code={self.code!r} type={self.type!r} '
f'online={self.approximate_presence_count} ' f'guild={self.guild!r} '
f'members={self.approximate_member_count}>' f'members={self.approximate_member_count}>'
) )
@ -458,30 +486,68 @@ class Invite(Hashable):
""":class:`str`: A property that retrieves the invite URL.""" """:class:`str`: A property that retrieves the invite URL."""
return self.BASE + '/' + self.code return self.BASE + '/' + self.code
async def use(self) -> Guild: async def use(self) -> Union[Guild, User, GroupChannel]:
"""|coro| """|coro|
Uses the invite (joins the guild) Uses the invite.
Either joins a guild, joins a group DM, or adds a friend.
There is an alias for this called :func:`accept`.
.. versionadded:: 1.9 .. versionadded:: 1.9
Raises Raises
------ ------
:exc:`.HTTPException` :exc:`.HTTPException`
Joining the guild failed. Using the invite failed.
Returns Returns
------- -------
:class:`.Guild` Union[:class:`Guild`, :class:`User`, :class:`GroupChannel`]
The guild joined. This is not the same guild that is The guild/group DM joined, or user added as a friend.
added to cache.
""" """
state = self._state state = self._state
data = await state.http.join_guild(self.code, self.guild.id, self.channel.id, self.channel.type.value, self._message_id) type = self.type
return state.Guild(data=data['guild'], state=state) # Circular import if (message := self._message):
kwargs = {'message': message}
else:
kwargs = {
'guild_id': getattr(self.guild, 'id', MISSING),
'channel_id': getattr(self.channel, 'id', MISSING),
'channel_type': getattr(self.channel, 'type', MISSING),
}
data = await state.http.accept_invite(self.code, type, **kwargs)
if type is InviteType.guild:
from .guild import Guild
return Guild(data=data['guild'], state=state)
elif type is InviteType.group_dm:
from .channel import GroupChannel
return GroupChannel(data=data['channel'], state=state, me=state.user) # type: ignore
else:
from .user import User
return User(data=data['inviter'], state=state)
async def accept(self) -> Union[Guild, User, GroupChannel]:
"""|coro|
Uses the invite.
Either joins a guild, joins a group DM, or adds a friend.
This is an alias of :func:`use`.
.. versionadded:: 1.9
accept = use Raises
------
:exc:`.HTTPException`
Using the invite failed.
Returns
-------
Union[:class:`Guild`, :class:`User`, :class:`GroupChannel`]
The guild/group DM joined, or user added as a friend.
"""
return await self.use()
async def delete(self, *, reason: Optional[str] = None): async def delete(self, *, reason: Optional[str] = None):
"""|coro| """|coro|
@ -504,5 +570,4 @@ class Invite(Hashable):
HTTPException HTTPException
Revoking the invite failed. Revoking the invite failed.
""" """
await self._state.http.delete_invite(self.code, reason=reason) await self._state.http.delete_invite(self.code, reason=reason)

Loading…
Cancel
Save