Browse Source

Fix typing errors

pull/10109/head
dolfies 4 years ago
parent
commit
cbab87d287
  1. 46
      discord/http.py

46
discord/http.py

@ -29,7 +29,6 @@ from base64 import b64encode
import json
import logging
from random import choice, getrandbits
import sys
from typing import (
Any,
ClassVar,
@ -40,7 +39,6 @@ from typing import (
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
@ -51,7 +49,7 @@ import weakref
import aiohttp
from .enums import RelationshipAction
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, InvalidArgument
from . import utils
from .tracking import ContextProperties
from .utils import MISSING
@ -66,12 +64,10 @@ if TYPE_CHECKING:
appinfo,
audit_log,
channel,
components,
emoji,
embed,
guild,
integration,
interactions,
invite,
member,
message,
@ -79,11 +75,8 @@ if TYPE_CHECKING:
role,
user,
webhook,
channel,
widget,
threads,
voice,
snowflake,
sticker,
)
from .types.snowflake import Snowflake, SnowflakeList
@ -446,7 +439,7 @@ class HTTPClient:
self.token = token
self.ack_token = None
def get_me(self, with_analytics_token=True) -> user.User:
def get_me(self, with_analytics_token=True) -> Response[user.User]:
params = {
'with_analytics_token': str(with_analytics_token).lower()
}
@ -496,7 +489,7 @@ class HTTPClient:
return self.request(Route('PATCH', '/channels/{channel_id}', channel_id=channel_id), json=payload)
def get_private_channels(self) -> Response[List[channel.PrivateChannel]]:
def get_private_channels(self) -> Response[List[Union[channel.DMChannel, channel.GroupDMChannel]]]:
return self.request(Route('GET', '/users/@me/channels'))
def start_private_message(self, user_id: Snowflake) -> Response[channel.DMChannel]:
@ -517,13 +510,13 @@ class HTTPClient:
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[Union[int, str]] = None,
nonce: Optional[Snowflake] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
) -> Response[message.Message]:
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
payload = {'tts': tts}
payload: Dict[str, Any] = {'tts': tts}
if content:
payload['content'] = content
if embed:
@ -553,13 +546,13 @@ class HTTPClient:
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[Iterable[Optional[embed.Embed]]] = None,
nonce: Optional[str] = None,
nonce: Optional[Snowflake] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
) -> Response[message.Message]:
form = []
payload = {'tts': tts}
payload: Dict[str, Any] = {'tts': tts}
if content:
payload['content'] = content
@ -605,7 +598,7 @@ class HTTPClient:
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[str] = None,
nonce: Optional[Snowflake] = None,
allowed_mentions: Optional[message.AllowedMentions] = None,
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[List[sticker.StickerItem]] = None,
@ -743,16 +736,17 @@ class HTTPClient:
)
return self.request(r)
async def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]:
async def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> message.Message:
data = await self.logs_from(channel_id, 1, around=message_id)
try:
msg = data[0]
except IndexError:
raise NotFound(_FakeResponse('Not Found', 404), 'message not found')
if int(msg.get('id')) != message_id:
raise NotFound(_FakeResponse('Not Found', 404), 'message not found')
if int(msg.get('id')) == message_id:
return msg
raise NotFound(_FakeResponse('Not Found', 404), 'message not found')
return msg
def get_channel(self, channel_id: Snowflake) -> Response[channel.Channel]:
return self.request(Route('GET', '/channels/{channel_id}', channel_id=channel_id))
@ -1132,8 +1126,8 @@ class HTTPClient:
after: Optional[Snowflake] = None,
with_counts: bool = True
) -> Response[List[guild.Guild]]:
params = {
'with_counts': with_counts
params: Dict[str, Snowflake] = {
'with_counts': str(with_counts).lower()
}
if limit and limit != 200:
params['limit'] = limit
@ -1291,9 +1285,9 @@ class HTTPClient:
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
def list_premium_sticker_packs(
self, country: str = 'US', locale: str = 'en-US', payment_source_id: int = MISSING
self, country: str = 'US', locale: str = 'en-US', payment_source_id: Snowflake = MISSING
) -> Response[sticker.ListPremiumStickerPacks]:
params = {
params: Dict[str, Snowflake] = {
'country_code': country,
'locale': locale,
}
@ -1469,7 +1463,7 @@ class HTTPClient:
action_type: Optional[AuditLogAction] = None,
) -> Response[audit_log.AuditLog]:
r = Route('GET', '/guilds/{guild_id}/audit-logs', guild_id=guild_id)
params = {
params: Dict[str, Any] = {
'limit': limit
}
if before:
@ -1730,7 +1724,7 @@ class HTTPClient:
elif action == RelationshipAction.remove_pending_request: # Friends
props = ContextProperties._from_friends_page()
return self.request(r, context_properties=props)
return self.request(r, context_properties=props) # type: ignore
def add_relationship(
self, user_id: Snowflake, type: int = MISSING, *, action: RelationshipAction
@ -1756,7 +1750,7 @@ class HTTPClient:
ContextProperties._from_dm_channel()
))
kwargs = {
'context_properties': props
'context_properties': props # type: ignore
}
if type:
kwargs['json'] = {'type': type}

Loading…
Cancel
Save