From 1a1b9cf15a8570b2f702db321483e26634961ad6 Mon Sep 17 00:00:00 2001 From: dolfies Date: Mon, 6 Dec 2021 14:16:46 -0500 Subject: [PATCH] Implement connections, add fetch_sticker_pack, fix some small issues --- discord/activity.py | 2 +- discord/client.py | 58 ++++++++++++++--- discord/connections.py | 139 +++++++++++++++++++++++++++++++++++++++++ discord/guild.py | 4 +- discord/http.py | 20 +++++- discord/profile.py | 8 +-- discord/state.py | 3 + 7 files changed, 215 insertions(+), 19 deletions(-) create mode 100644 discord/connections.py diff --git a/discord/activity.py b/discord/activity.py index 999c46dce..9754e7013 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -744,7 +744,7 @@ class CustomActivity(BaseActivity): super().__init__(**extra) self.name: Optional[str] = name self.state = state = extra.pop('state', None) - if self.name == 'Custom Activity': + if self.name == 'Custom Status': self.name = state self.emoji: Optional[PartialEmoji] diff --git a/discord/client.py b/discord/client.py index a63b9b26f..7cf1a68ce 100644 --- a/discord/client.py +++ b/discord/client.py @@ -60,6 +60,7 @@ from .stage_instance import StageInstance from .threads import Thread from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .profile import UserProfile +from .connections import Connection if TYPE_CHECKING: from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake @@ -1098,7 +1099,7 @@ class Client: The preferred region to connect to. """ state = self._connection - ws = state._get_websocket(self.id) + ws = self.ws channel_id = channel.id if channel else None if preferred_region is None or channel_id is None: @@ -1106,7 +1107,7 @@ class Client: else: region = str(preferred_region) if preferred_region else str(state.preferred_region) - await ws.voice_state(None, channel_id, self_mute, self_deaf, self_video, region) + await ws.voice_state(None, channel_id, self_mute, self_deaf, self_video, preferred_region=region) # Guild stuff @@ -1233,7 +1234,6 @@ class Client: self, *, name: str, - region: Union[VoiceRegion, str] = VoiceRegion.us_west, icon: bytes = MISSING, code: str = MISSING, ) -> Guild: @@ -1402,7 +1402,7 @@ class Client: if not isinstance(invite, Invite): invite = await self.fetch_invite(invite, with_counts=False, with_expiration=False) - data = await self.http.join_guild(invite.code, guild_id=invite.guild.id, channel_id=invite.channel.id, channel_type=invite.channel.type.value) + data = await self.http.accept_invite(invite.code, guild_id=invite.guild.id, channel_id=invite.channel.id, channel_type=invite.channel.type.value) return Guild(data=data['guild'], state=self._connection) use_invite = accept_invite @@ -1469,7 +1469,6 @@ class Client: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_user_profile( self, user_id: int, /, *, with_mutuals: bool = True, fetch_note: bool = True ) -> UserProfile: @@ -1553,10 +1552,10 @@ class Client: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): - # the factory will be a DMChannel or GroupChannel here + # The factory will be a DMChannel or GroupChannel here channel = factory(me=self.user, data=data, state=self._connection) # type: ignore else: - # the factory can't be a DMChannel or GroupChannel here + # The factory can't be a DMChannel or GroupChannel here guild_id = int(data['guild_id']) # type: ignore guild = self.get_guild(guild_id) or Object(id=guild_id) # GuildChannels expect a Guild, we may be passing an Object @@ -1637,12 +1636,32 @@ class Client: Returns --------- List[:class:`.StickerPack`] - All available premium sticker packs. + All available sticker packs. """ data = await self.http.list_premium_sticker_packs(country, locale, payment_source_id) return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']] - async def fetch_notes(self) -> List[Note]: + async def fetch_sticker_pack(self, pack_id: int, /): + """|coro| + + Retrieves a sticker pack with the specified ID. + + Raises + ------- + :exc:`.NotFound` + A sticker pack with that ID was not found. + :exc:`.HTTPException` + Retrieving the sticker packs failed. + + Returns + ------- + :class:`.StickerPack` + The sticker pack you requested. + """ + data = await self.http.get_sticker_pack(pack_id) + return StickerPack(state=self._connection, data=data) + + async def notes(self) -> List[Note]: """|coro| Retrieves a list of :class:`Note` objects representing all your notes. @@ -1685,6 +1704,25 @@ class Client: await note.fetch() return note + async def connections(self) -> List[Connection]: + """|coro| + + Retrieves all of your connections. + + Raises + ------- + :exc:`.HTTPException` + Retreiving your connections failed. + + Returns + ------- + List[:class:`Connection`] + All your connections. + """ + state = self._connection + data = await state.http.get_connections() + return [Connection(data=d, state=state) for d in data] + async def fetch_private_channels(self) -> List[PrivateChannel]: """|coro| @@ -1702,7 +1740,7 @@ class Client: """ state = self._connection channels = await state.http.get_private_channels() - return [_private_channel_factory(data['type'])(me=self.user, data=data, state=state) for data in channels] + return [_private_channel_factory(data['type'])[0](me=self.user, data=data, state=state) for data in channels] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| diff --git a/discord/connections.py b/discord/connections.py new file mode 100644 index 000000000..b5b886628 --- /dev/null +++ b/discord/connections.py @@ -0,0 +1,139 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Dolfies + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import Optional + +from .utils import MISSING + + +class PartialConnection: + """Represents a partial Discord profile connection + + This is the info you get for other people's connections. + + Attributes + ---------- + id: :class:`str` + The connection's account ID. + name: :class:`str` + The connection's account name. + type: :class:`str` + The connection service (e.g. 'youtube') + verified: :class:`bool` + Whether the connection is verified. + revoked: :class:`bool` + Whether the connection is revoked. + visible: :class:`bool` + Whether the connection is visible on the user's profile. + """ + + __slots__ = ('id', 'name', 'type', 'verified', 'revoked', 'visible') + + def __init__(self, data): + self.id: str = data['id'] + self.name: str = data['name'] + self.type: str = data['type'] + + self.verified: bool = data['verified'] + self.revoked: bool = data.get('revoked', False) + self.visible: bool = True + + +class Connection(PartialConnection): + """Represents a Discord profile connection + + Attributes + ---------- + id: :class:`str` + The connection's account ID. + name: :class:`str` + The connection's account name. + type: :class:`str` + The connection service (e.g. 'youtube') + verified: :class:`bool` + Whether the connection is verified. + revoked: :class:`bool` + Whether the connection is revoked. + visible: :class:`bool` + Whether the connection is visible on the user's profile. + friend_sync: :class:`bool` + Whether friends are synced over the connection. + show_activity: :class:`bool` + Whether activities from this connection will be shown in presences. + access_token: :class:`str` + The OAuth2 access token for the account, if applicable. + """ + + __slots__ = ('_state', 'visible', 'friend_sync', 'show_activity', 'access_token') + + def __init__(self, *, data, state): + self._state = state + super().__init__(data) + + self.visible: bool = bool(data.get('visibility', True)) + self.friend_sync: bool = data.get('friend_sync', False) + self.show_activity: bool = data.get('show_activity', True) + self.access_token: Optional[str] = data.get('access_token') + + async def edit(self, *, visible: bool = MISSING): + """|coro| + + Edit the connection. + + All parameters are optional. + + Parameters + ---------- + visible: :class:`bool` + Whether the connection is visible on your profile. + + Raises + ------ + HTTPException + Editing the connection failed. + + Returns + ------- + :class:`Connection` + The new connection. + """ + if visible is not MISSING: + data = await self._state.http.edit_connection(self.type, self.id, visibility=visible) + return Connection(data=data, state=self._state) + else: + return self + + async def delete(self): + """|coro| + + Removes the connection. + + Raises + ------ + HTTPException + Deleting the connection failed. + """ + await self._state.http.delete_connection(self.type, self.id) diff --git a/discord/guild.py b/discord/guild.py index 50d8e8bd9..b265b4ad9 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -2944,7 +2944,7 @@ class Guild(Hashable): The preferred region to connect to. """ state = self._state - ws = state._get_websocket(self.id) + ws = state.ws channel_id = channel.id if channel else None if preferred_region is None or channel_id is None: @@ -2952,4 +2952,4 @@ class Guild(Hashable): else: region = str(preferred_region) if preferred_region else str(state.preferred_region) - await ws.voice_state(self.id, channel_id, self_mute, self_deaf, self_video, region) + await ws.voice_state(self.id, channel_id, self_mute, self_deaf, self_video, preferred_region=region) diff --git a/discord/http.py b/discord/http.py index 0523b1e79..c00bf38f1 100644 --- a/discord/http.py +++ b/discord/http.py @@ -283,12 +283,13 @@ class HTTPClient: 'Sec-Fetch-Mode': 'cors', 'Sec-Fetch-Site': 'same-origin', 'User-Agent': self.user_agent, + 'X-Discord-Locale': 'en-US', 'X-Debug-Options': 'bugReporterEnabled', 'X-Super-Properties': self.encoded_super_properties } # Header modification - if self.token is not None: + if self.token is not None and kwargs.get('auth', True): headers['Authorization'] = self.token reason = kwargs.pop('reason', None) @@ -1296,6 +1297,9 @@ class HTTPClient: return self.request(Route('GET', '/sticker-packs'), params=params) + def get_sticker_pack(self, pack_id: Snowflake): + return self.request(Route('GET', '/sticker-packs/{pack_id}', pack_id=pack_id), auth=False) + def get_all_guild_stickers(self, guild_id: Snowflake) -> Response[List[sticker.GuildSticker]]: return self.request(Route('GET', '/guilds/{guild_id}/stickers', guild_id=guild_id)) @@ -1834,6 +1838,15 @@ class HTTPClient: def edit_settings(self, **payload): # TODO: return type, is this cheating? return self.request(Route('PATCH', '/users/@me/settings'), json=payload) + def get_connections(self): + return self.request(Route('GET', '/users/@me/connections')) + + def edit_connection(self, type, id, **payload): + return self.request(Route('PATCH', '/users/@me/connections/{type}/{id}', type=type, id=id), json=payload) + + def delete_connection(self, type, id): + return self.request(Route('DELETE', '/users/@me/connections/{type}/{id}', type=type, id=id)) + def get_applications(self, *, with_team_applications: bool = True) -> Response[List[appinfo.AppInfo]]: params = { 'with_team_applications': str(with_team_applications).lower() @@ -1844,6 +1857,9 @@ class HTTPClient: def get_my_application(self, app_id: Snowflake) -> Response[appinfo.AppInfo]: return self.request(Route('GET', '/applications/{app_id}', app_id=app_id), super_properties_to_track=True) + def get_partial_application(self, app_id: Snowflake): + return self.request(Route('GET', '/applications/{app_id}/rpc', app_id=app_id), auth=False) + def get_app_entitlements(self, app_id: Snowflake): # TODO: return type r = Route('GET', '/users/@me/applications/{app_id}/entitlements', app_id=app_id) return self.request(r, super_properties_to_track=True) @@ -1868,7 +1884,7 @@ class HTTPClient: def get_team(self, team_id: Snowflake): # TODO: return type return self.request(Route('GET', '/teams/{team_id}', team_id=team_id), super_properties_to_track=True) - def mobile_report( + def mobile_report( # Report v1 self, guild_id: Snowflake, channel_id: Snowflake, message_id: Snowflake, reason: str ): # TODO: return type payload = { diff --git a/discord/profile.py b/discord/profile.py index 5a88411dc..778597d8e 100644 --- a/discord/profile.py +++ b/discord/profile.py @@ -24,8 +24,9 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import List, Optional, Protocol, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING +from .connections import PartialConnection from .flags import PrivateUserFlags from .member import Member from .user import Note, User @@ -50,9 +51,8 @@ class Profile: ``None`` if the user is not a premium user. boosting_since: Optional[:class:`datetime.datetime`] An aware datetime object that specifies when a user first boosted a guild. - connected_accounts: Optional[List[:class:`dict`]] + connections: Optional[List[:class:`PartialConnection`]] The connected accounts that show up on the profile. - These are currently just the raw json, but this will change in the future. note: :class:`Note` Represents the note on the profile. mutual_guilds: Optional[List[:class:`Guild`]] @@ -85,7 +85,7 @@ class Profile: self.premium_since: Optional[datetime] = parse_time(data['premium_since']) self.boosting_since: Optional[datetime] = parse_time(data['premium_guild_since']) - self.connected_accounts: List[dict] = data['connected_accounts'] # TODO: parse these + self.connections: List[PartialConnection] = [PartialConnection(d) for d in data['connection_accounts']] # TODO: parse these self.mutual_guilds: Optional[List[Guild]] = self._parse_mutual_guilds(data.get('mutual_guilds')) self.mutual_friends: Optional[List[User]] = self._parse_mutual_friends(data.get('mutual_friends')) diff --git a/discord/state.py b/discord/state.py index cad92dc19..03b2527f9 100644 --- a/discord/state.py +++ b/discord/state.py @@ -672,6 +672,9 @@ class ConnectionState: self.settings = UserSettings(data=data.get('user_settings', {}), state=self) self.consents = Tracking(data.get('consents', {})) + if 'required_action' in data: # Locked more than likely + self.parse_user_required_action_update(data) + # We're done del self._ready_data self.call_handlers('connect')