Browse Source

Fix typing errors

pull/10109/head
dolfies 4 years ago
parent
commit
cbab87d287
  1. 46
      discord/http.py

46
discord/http.py

@ -29,7 +29,6 @@ from base64 import b64encode
import json import json
import logging import logging
from random import choice, getrandbits from random import choice, getrandbits
import sys
from typing import ( from typing import (
Any, Any,
ClassVar, ClassVar,
@ -40,7 +39,6 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
TYPE_CHECKING, TYPE_CHECKING,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -51,7 +49,7 @@ import weakref
import aiohttp import aiohttp
from .enums import RelationshipAction from .enums import RelationshipAction
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, InvalidArgument
from . import utils from . import utils
from .tracking import ContextProperties from .tracking import ContextProperties
from .utils import MISSING from .utils import MISSING
@ -66,12 +64,10 @@ if TYPE_CHECKING:
appinfo, appinfo,
audit_log, audit_log,
channel, channel,
components,
emoji, emoji,
embed, embed,
guild, guild,
integration, integration,
interactions,
invite, invite,
member, member,
message, message,
@ -79,11 +75,8 @@ if TYPE_CHECKING:
role, role,
user, user,
webhook, webhook,
channel,
widget, widget,
threads, threads,
voice,
snowflake,
sticker, sticker,
) )
from .types.snowflake import Snowflake, SnowflakeList from .types.snowflake import Snowflake, SnowflakeList
@ -446,7 +439,7 @@ class HTTPClient:
self.token = token self.token = token
self.ack_token = None self.ack_token = None
def get_me(self, with_analytics_token=True) -> user.User: def get_me(self, with_analytics_token=True) -> Response[user.User]:
params = { params = {
'with_analytics_token': str(with_analytics_token).lower() 'with_analytics_token': str(with_analytics_token).lower()
} }
@ -496,7 +489,7 @@ class HTTPClient:
return self.request(Route('PATCH', '/channels/{channel_id}', channel_id=channel_id), json=payload) return self.request(Route('PATCH', '/channels/{channel_id}', channel_id=channel_id), json=payload)
def get_private_channels(self) -> Response[List[channel.PrivateChannel]]: def get_private_channels(self) -> Response[List[Union[channel.DMChannel, channel.GroupDMChannel]]]:
return self.request(Route('GET', '/users/@me/channels')) return self.request(Route('GET', '/users/@me/channels'))
def start_private_message(self, user_id: Snowflake) -> Response[channel.DMChannel]: def start_private_message(self, user_id: Snowflake) -> Response[channel.DMChannel]:
@ -517,13 +510,13 @@ class HTTPClient:
tts: bool = False, tts: bool = False,
embed: Optional[embed.Embed] = None, embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None, embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[Union[int, str]] = None, nonce: Optional[Snowflake] = None,
allowed_mentions: Optional[message.AllowedMentions] = None, allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None, message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None, stickers: Optional[List[sticker.StickerItem]] = None,
) -> Response[message.Message]: ) -> Response[message.Message]:
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
payload = {'tts': tts} payload: Dict[str, Any] = {'tts': tts}
if content: if content:
payload['content'] = content payload['content'] = content
if embed: if embed:
@ -553,13 +546,13 @@ class HTTPClient:
tts: bool = False, tts: bool = False,
embed: Optional[embed.Embed] = None, embed: Optional[embed.Embed] = None,
embeds: Optional[Iterable[Optional[embed.Embed]]] = None, embeds: Optional[Iterable[Optional[embed.Embed]]] = None,
nonce: Optional[str] = None, nonce: Optional[Snowflake] = None,
allowed_mentions: Optional[message.AllowedMentions] = None, allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None, message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None, stickers: Optional[List[sticker.StickerItem]] = None,
) -> Response[message.Message]: ) -> Response[message.Message]:
form = [] form = []
payload = {'tts': tts} payload: Dict[str, Any] = {'tts': tts}
if content: if content:
payload['content'] = content payload['content'] = content
@ -605,7 +598,7 @@ class HTTPClient:
tts: bool = False, tts: bool = False,
embed: Optional[embed.Embed] = None, embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None, embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[str] = None, nonce: Optional[Snowflake] = None,
allowed_mentions: Optional[message.AllowedMentions] = None, allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None, message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None, stickers: Optional[List[sticker.StickerItem]] = None,
@ -743,16 +736,17 @@ class HTTPClient:
) )
return self.request(r) return self.request(r)
async def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: async def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> message.Message:
data = await self.logs_from(channel_id, 1, around=message_id) data = await self.logs_from(channel_id, 1, around=message_id)
try: try:
msg = data[0] msg = data[0]
except IndexError: except IndexError:
raise NotFound(_FakeResponse('Not Found', 404), 'message not found') raise NotFound(_FakeResponse('Not Found', 404), 'message not found')
if int(msg.get('id')) != message_id:
raise NotFound(_FakeResponse('Not Found', 404), 'message not found')
if int(msg.get('id')) == message_id: return msg
return msg
raise NotFound(_FakeResponse('Not Found', 404), 'message not found')
def get_channel(self, channel_id: Snowflake) -> Response[channel.Channel]: def get_channel(self, channel_id: Snowflake) -> Response[channel.Channel]:
return self.request(Route('GET', '/channels/{channel_id}', channel_id=channel_id)) return self.request(Route('GET', '/channels/{channel_id}', channel_id=channel_id))
@ -1132,8 +1126,8 @@ class HTTPClient:
after: Optional[Snowflake] = None, after: Optional[Snowflake] = None,
with_counts: bool = True with_counts: bool = True
) -> Response[List[guild.Guild]]: ) -> Response[List[guild.Guild]]:
params = { params: Dict[str, Snowflake] = {
'with_counts': with_counts 'with_counts': str(with_counts).lower()
} }
if limit and limit != 200: if limit and limit != 200:
params['limit'] = limit params['limit'] = limit
@ -1291,9 +1285,9 @@ class HTTPClient:
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id)) return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
def list_premium_sticker_packs( def list_premium_sticker_packs(
self, country: str = 'US', locale: str = 'en-US', payment_source_id: int = MISSING self, country: str = 'US', locale: str = 'en-US', payment_source_id: Snowflake = MISSING
) -> Response[sticker.ListPremiumStickerPacks]: ) -> Response[sticker.ListPremiumStickerPacks]:
params = { params: Dict[str, Snowflake] = {
'country_code': country, 'country_code': country,
'locale': locale, 'locale': locale,
} }
@ -1469,7 +1463,7 @@ class HTTPClient:
action_type: Optional[AuditLogAction] = None, action_type: Optional[AuditLogAction] = None,
) -> Response[audit_log.AuditLog]: ) -> Response[audit_log.AuditLog]:
r = Route('GET', '/guilds/{guild_id}/audit-logs', guild_id=guild_id) r = Route('GET', '/guilds/{guild_id}/audit-logs', guild_id=guild_id)
params = { params: Dict[str, Any] = {
'limit': limit 'limit': limit
} }
if before: if before:
@ -1730,7 +1724,7 @@ class HTTPClient:
elif action == RelationshipAction.remove_pending_request: # Friends elif action == RelationshipAction.remove_pending_request: # Friends
props = ContextProperties._from_friends_page() props = ContextProperties._from_friends_page()
return self.request(r, context_properties=props) return self.request(r, context_properties=props) # type: ignore
def add_relationship( def add_relationship(
self, user_id: Snowflake, type: int = MISSING, *, action: RelationshipAction self, user_id: Snowflake, type: int = MISSING, *, action: RelationshipAction
@ -1756,7 +1750,7 @@ class HTTPClient:
ContextProperties._from_dm_channel() ContextProperties._from_dm_channel()
)) ))
kwargs = { kwargs = {
'context_properties': props 'context_properties': props # type: ignore
} }
if type: if type:
kwargs['json'] = {'type': type} kwargs['json'] = {'type': type}

Loading…
Cancel
Save