diff --git a/discord/http.py b/discord/http.py index 6f10e6a40..988aa0a19 100644 --- a/discord/http.py +++ b/discord/http.py @@ -29,7 +29,6 @@ from base64 import b64encode import json import logging from random import choice, getrandbits -import sys from typing import ( Any, ClassVar, @@ -40,7 +39,6 @@ from typing import ( Optional, Sequence, TYPE_CHECKING, - Tuple, Type, TypeVar, Union, @@ -51,7 +49,7 @@ import weakref import aiohttp from .enums import RelationshipAction -from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument +from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, InvalidArgument from . import utils from .tracking import ContextProperties from .utils import MISSING @@ -66,12 +64,10 @@ if TYPE_CHECKING: appinfo, audit_log, channel, - components, emoji, embed, guild, integration, - interactions, invite, member, message, @@ -79,11 +75,8 @@ if TYPE_CHECKING: role, user, webhook, - channel, widget, threads, - voice, - snowflake, sticker, ) from .types.snowflake import Snowflake, SnowflakeList @@ -446,7 +439,7 @@ class HTTPClient: self.token = token self.ack_token = None - def get_me(self, with_analytics_token=True) -> user.User: + def get_me(self, with_analytics_token=True) -> Response[user.User]: params = { 'with_analytics_token': str(with_analytics_token).lower() } @@ -496,7 +489,7 @@ class HTTPClient: return self.request(Route('PATCH', '/channels/{channel_id}', channel_id=channel_id), json=payload) - def get_private_channels(self) -> Response[List[channel.PrivateChannel]]: + def get_private_channels(self) -> Response[List[Union[channel.DMChannel, channel.GroupDMChannel]]]: return self.request(Route('GET', '/users/@me/channels')) def start_private_message(self, user_id: Snowflake) -> Response[channel.DMChannel]: @@ -517,13 +510,13 @@ class HTTPClient: tts: bool = False, embed: Optional[embed.Embed] = None, embeds: Optional[List[embed.Embed]] = None, - nonce: Optional[Union[int, str]] = None, + nonce: Optional[Snowflake] = None, allowed_mentions: Optional[message.AllowedMentions] = None, message_reference: Optional[message.MessageReference] = None, stickers: Optional[List[sticker.StickerItem]] = None, ) -> Response[message.Message]: r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) - payload = {'tts': tts} + payload: Dict[str, Any] = {'tts': tts} if content: payload['content'] = content if embed: @@ -553,13 +546,13 @@ class HTTPClient: tts: bool = False, embed: Optional[embed.Embed] = None, embeds: Optional[Iterable[Optional[embed.Embed]]] = None, - nonce: Optional[str] = None, + nonce: Optional[Snowflake] = None, allowed_mentions: Optional[message.AllowedMentions] = None, message_reference: Optional[message.MessageReference] = None, stickers: Optional[List[sticker.StickerItem]] = None, ) -> Response[message.Message]: form = [] - payload = {'tts': tts} + payload: Dict[str, Any] = {'tts': tts} if content: payload['content'] = content @@ -605,7 +598,7 @@ class HTTPClient: tts: bool = False, embed: Optional[embed.Embed] = None, embeds: Optional[List[embed.Embed]] = None, - nonce: Optional[str] = None, + nonce: Optional[Snowflake] = None, allowed_mentions: Optional[message.AllowedMentions] = None, message_reference: Optional[message.MessageReference] = None, stickers: Optional[List[sticker.StickerItem]] = None, @@ -743,16 +736,17 @@ class HTTPClient: ) return self.request(r) - async def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: + async def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> message.Message: data = await self.logs_from(channel_id, 1, around=message_id) + try: msg = data[0] except IndexError: raise NotFound(_FakeResponse('Not Found', 404), 'message not found') + if int(msg.get('id')) != message_id: + raise NotFound(_FakeResponse('Not Found', 404), 'message not found') - if int(msg.get('id')) == message_id: - return msg - raise NotFound(_FakeResponse('Not Found', 404), 'message not found') + return msg def get_channel(self, channel_id: Snowflake) -> Response[channel.Channel]: return self.request(Route('GET', '/channels/{channel_id}', channel_id=channel_id)) @@ -1132,8 +1126,8 @@ class HTTPClient: after: Optional[Snowflake] = None, with_counts: bool = True ) -> Response[List[guild.Guild]]: - params = { - 'with_counts': with_counts + params: Dict[str, Snowflake] = { + 'with_counts': str(with_counts).lower() } if limit and limit != 200: params['limit'] = limit @@ -1291,9 +1285,9 @@ class HTTPClient: return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id)) def list_premium_sticker_packs( - self, country: str = 'US', locale: str = 'en-US', payment_source_id: int = MISSING + self, country: str = 'US', locale: str = 'en-US', payment_source_id: Snowflake = MISSING ) -> Response[sticker.ListPremiumStickerPacks]: - params = { + params: Dict[str, Snowflake] = { 'country_code': country, 'locale': locale, } @@ -1469,7 +1463,7 @@ class HTTPClient: action_type: Optional[AuditLogAction] = None, ) -> Response[audit_log.AuditLog]: r = Route('GET', '/guilds/{guild_id}/audit-logs', guild_id=guild_id) - params = { + params: Dict[str, Any] = { 'limit': limit } if before: @@ -1730,7 +1724,7 @@ class HTTPClient: elif action == RelationshipAction.remove_pending_request: # Friends props = ContextProperties._from_friends_page() - return self.request(r, context_properties=props) + return self.request(r, context_properties=props) # type: ignore def add_relationship( self, user_id: Snowflake, type: int = MISSING, *, action: RelationshipAction @@ -1756,7 +1750,7 @@ class HTTPClient: ContextProperties._from_dm_channel() )) kwargs = { - 'context_properties': props + 'context_properties': props # type: ignore } if type: kwargs['json'] = {'type': type}