From 750ba88f2c9d1d51ecf3accdd325bd83aa8ae95b Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 4 Jul 2021 07:55:20 -0400 Subject: [PATCH] Fix typing errors with Client --- discord/activity.py | 3 +++ discord/client.py | 35 ++++++++++++++++++++++------------- discord/http.py | 10 ++++++---- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/discord/activity.py b/discord/activity.py index d0a1af9a7..cbe1552f6 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -133,6 +133,9 @@ class BaseActivity: if self._created_at is not None: return datetime.datetime.utcfromtimestamp(self._created_at / 1000).replace(tzinfo=datetime.timezone.utc) + def to_dict(self) -> ActivityPayload: + raise NotImplementedError + class Activity(BaseActivity): """Represents an activity in Discord. diff --git a/discord/client.py b/discord/client.py index 24ceb31b7..d423fcfff 100644 --- a/discord/client.py +++ b/discord/client.py @@ -46,11 +46,12 @@ from .errors import * from .enums import Status, VoiceRegion from .flags import ApplicationFlags, Intents from .gateway import * -from .activity import BaseActivity, create_activity +from .activity import ActivityTypes, BaseActivity, create_activity from .voice_client import VoiceClient from .http import HTTPClient from .state import ConnectionState from . import utils +from .utils import MISSING from .object import Object from .backoff import ExponentialBackoff from .webhook import Webhook @@ -649,14 +650,14 @@ class Client: return self._closed @property - def activity(self) -> Optional[BaseActivity]: + def activity(self) -> Optional[ActivityTypes]: """Optional[:class:`.BaseActivity`]: The activity being used upon logging in. """ return create_activity(self._connection._activity) @activity.setter - def activity(self, value: Optional[BaseActivity]) -> None: + def activity(self, value: Optional[ActivityTypes]) -> None: if value is None: self._connection._activity = None elif isinstance(value, BaseActivity): @@ -1029,7 +1030,7 @@ class Client: limit: Optional[int] = 100, before: SnowflakeTime = None, after: SnowflakeTime = None - ) -> List[Guild]: + ) -> GuildIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. .. note:: @@ -1144,7 +1145,14 @@ class Client: data = await self.http.get_guild(guild_id) return Guild(data=data, state=self._connection) - async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None) -> Guild: + async def create_guild( + self, + *, + name: str, + region: Union[VoiceRegion, str] = VoiceRegion.us_west, + icon: bytes = MISSING, + code: str = MISSING, + ) -> Guild: """|coro| Creates a :class:`.Guild`. @@ -1158,10 +1166,10 @@ class Client: region: :class:`.VoiceRegion` The region for the voice communication server. Defaults to :attr:`.VoiceRegion.us_west`. - icon: :class:`bytes` + icon: Optional[:class:`bytes`] The :term:`py:bytes-like object` representing the icon. See :meth:`.ClientUser.edit` for more details on what is expected. - code: Optional[:class:`str`] + code: :class:`str` The code for a template to create the guild with. .. versionadded:: 1.4 @@ -1179,16 +1187,17 @@ class Client: The guild created. This is not the same guild that is added to cache. """ - if icon is not None: - icon = utils._bytes_to_base64_data(icon) + if icon is not MISSING: + icon_base64 = utils._bytes_to_base64_data(icon) + else: + icon_base64 = None - region = region or VoiceRegion.us_west - region_value = region.value + region_value = str(region) if code: - data = await self.http.create_from_template(code, name, region_value, icon) + data = await self.http.create_from_template(code, name, region_value, icon_base64) else: - data = await self.http.create_guild(name, region_value, icon) + data = await self.http.create_guild(name, region_value, icon_base64) return Guild(data=data, state=self._connection) async def fetch_stage_instance(self, channel_id: int) -> StageInstance: diff --git a/discord/http.py b/discord/http.py index 0ba8c2ce1..cfc71e22c 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1032,12 +1032,13 @@ class HTTPClient: def delete_guild(self, guild_id: Snowflake) -> Response[None]: return self.request(Route('DELETE', '/guilds/{guild_id}', guild_id=guild_id)) - def create_guild(self, name: str, region: str, icon: bytes) -> Response[guild.Guild]: + def create_guild(self, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: payload = { 'name': name, - 'icon': icon, 'region': region, } + if icon: + payload['icon'] = icon return self.request(Route('POST', '/guilds'), json=payload) @@ -1093,12 +1094,13 @@ class HTTPClient: def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]: return self.request(Route('DELETE', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) - def create_from_template(self, code: str, name: str, region: str, icon: bytes) -> Response[guild.Guild]: + def create_from_template(self, code: str, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: payload = { 'name': name, - 'icon': icon, 'region': region, } + if icon: + payload['icon'] = icon return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload) def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: