diff --git a/discord/channel.py b/discord/channel.py index 9b1825f90..c1b957527 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -48,7 +48,6 @@ __all__ = ( 'CategoryChannel', 'StoreChannel', 'GroupChannel', - '_channel_factory', ) if TYPE_CHECKING: @@ -1834,18 +1833,19 @@ class GroupChannel(discord.abc.Messageable, Hashable): await self._state.http.leave_group(self.id) -def _channel_factory(channel_type): - value = try_enum(ChannelType, channel_type) +def _coerce_channel_type(value: Union[ChannelType, int]) -> ChannelType: + if isinstance(value, ChannelType): + return value + return try_enum(ChannelType, value) + +def _guild_channel_factory(channel_type: Union[ChannelType, int]): + value = _coerce_channel_type(channel_type) if value is ChannelType.text: return TextChannel, value elif value is ChannelType.voice: return VoiceChannel, value - elif value is ChannelType.private: - return DMChannel, value elif value is ChannelType.category: return CategoryChannel, value - elif value is ChannelType.group: - return GroupChannel, value elif value is ChannelType.news: return TextChannel, value elif value is ChannelType.store: @@ -1854,3 +1854,12 @@ def _channel_factory(channel_type): return StageChannel, value else: return None, value + +def _channel_factory(channel_type: Union[ChannelType, int]): + cls, value = _guild_channel_factory(channel_type) + if value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return cls, value diff --git a/discord/guild.py b/discord/guild.py index e32c2fcf7..f94712034 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -25,8 +25,21 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import copy -from collections import namedtuple -from typing import Dict, List, Set, Literal, Optional, TYPE_CHECKING, Union, overload +from typing import ( + Any, + ClassVar, + Dict, + List, + NamedTuple, + Sequence, + Set, + Literal, + Optional, + TYPE_CHECKING, + Tuple, + Union, + overload, +) from . import utils, abc from .role import Role @@ -37,7 +50,18 @@ from .permissions import PermissionOverwrite from .colour import Colour from .errors import InvalidArgument, ClientException from .channel import * -from .enums import AuditLogAction, VideoQualityMode, VoiceRegion, ChannelType, try_enum, VerificationLevel, ContentFilter, NotificationLevel, NSFWLevel +from .channel import _guild_channel_factory +from .enums import ( + AuditLogAction, + VideoQualityMode, + VoiceRegion, + ChannelType, + try_enum, + VerificationLevel, + ContentFilter, + NotificationLevel, + NSFWLevel, +) from .mixins import Hashable from .user import User from .invite import Invite @@ -53,19 +77,38 @@ __all__ = ( 'Guild', ) +MISSING = utils.MISSING + if TYPE_CHECKING: - from .abc import SnowflakeTime - from .types.guild import ( - Ban as BanPayload + from .abc import Snowflake, SnowflakeTime + from .types.guild import Ban as BanPayload, Guild as GuildPayload, MFALevel + from .types.threads import ( + Thread as ThreadPayload, ) + from .types.voice import GuildVoiceState from .permissions import Permissions - from .channel import VoiceChannel, StageChannel + from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel from .template import Template + from .webhook import Webhook + from .state import ConnectionState + from .voice_client import VoiceProtocol + + import datetime VocalGuildChannel = Union[VoiceChannel, StageChannel] + GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel] + ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]] + + +class BanEntry(NamedTuple): + reason: Optional[str] + user: User + -BanEntry = namedtuple('BanEntry', 'reason user') -_GuildLimit = namedtuple('_GuildLimit', 'emoji bitrate filesize') +class _GuildLimit(NamedTuple): + emoji: int + bitrate: float + filesize: int class Guild(Hashable): @@ -174,18 +217,48 @@ class Guild(Hashable): .. versionadded:: 2.0 """ - __slots__ = ('afk_timeout', 'afk_channel', '_members', '_channels', '_icon', - 'name', 'id', 'unavailable', '_banner', 'region', '_state', - '_roles', '_member_count', '_large', - 'owner_id', 'mfa_level', 'emojis', 'features', - 'verification_level', 'explicit_content_filter', '_splash', - '_voice_states', '_system_channel_id', 'default_notifications', - 'description', 'max_presences', 'max_members', 'max_video_channel_users', - 'premium_tier', 'premium_subscription_count', '_system_channel_flags', - 'preferred_locale', '_discovery_splash', '_rules_channel_id', - '_public_updates_channel_id', '_stage_instances', 'nsfw_level', '_threads') - - _PREMIUM_GUILD_LIMITS = { + __slots__ = ( + 'afk_timeout', + 'afk_channel', + 'name', + 'id', + 'unavailable', + 'region', + 'owner_id', + 'mfa_level', + 'emojis', + 'features', + 'verification_level', + 'explicit_content_filter', + 'default_notifications', + 'description', + 'max_presences', + 'max_members', + 'max_video_channel_users', + 'premium_tier', + 'premium_subscription_count', + 'preferred_locale', + 'nsfw_level', + '_members', + '_channels', + '_icon', + '_banner', + '_state', + '_roles', + '_member_count', + '_large', + '_splash', + '_voice_states', + '_system_channel_id', + '_system_channel_flags', + '_discovery_splash', + '_rules_channel_id', + '_public_updates_channel_id', + '_stage_instances', + '_threads', + ) + + _PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = { None: _GuildLimit(emoji=50, bitrate=96e3, filesize=8388608), 0: _GuildLimit(emoji=50, bitrate=96e3, filesize=8388608), 1: _GuildLimit(emoji=100, bitrate=128e3, filesize=8388608), @@ -193,7 +266,50 @@ class Guild(Hashable): 3: _GuildLimit(emoji=250, bitrate=384e3, filesize=104857600), } - def __init__(self, *, data, state): + # The attributes are typed here due to the usage of late init + + name: str + region: VoiceRegion + verification_level: VerificationLevel + default_notifications: NotificationLevel + explicit_content_filter: ContentFilter + afk_timeout: int + unavailable: bool + id: int + mfa_level: MFALevel + emojis: Tuple[Emoji, ...] + features: List[str] + description: Optional[str] + max_presences: Optional[int] + max_members: Optional[int] + max_video_channel_users: Optional[int] + premium_tier: int + premium_subscription_count: int + preferred_locale: Optional[str] + nsfw_level: NSFWLevel + owner_id: Optional[int] + afk_channel: Optional[VocalGuildChannel] + + # These are private + + _channels: Dict[int, GuildChannel] + _members: Dict[int, Member] + _voice_states: Dict[int, VoiceState] + _threads: Dict[int, Thread] + _state: ConnectionState + _icon: Optional[str] + _banner: Optional[str] + _roles: Dict[int, Role] + _splash: Optional[str] + _system_channel_id: Optional[int] + _system_channel_flags: int + _discovery_splash: Optional[str] + _rules_channel_id: Optional[int] + _public_updates_channel_id: Optional[int] + _stage_instances: Dict[int, StageInstance] + _large: Optional[bool] + + def __init__(self, *, data: GuildPayload, state: ConnectionState): self._channels = {} self._members = {} self._voice_states = {} @@ -201,36 +317,36 @@ class Guild(Hashable): self._state = state self._from_data(data) - def _add_channel(self, channel): + def _add_channel(self, channel: GuildChannel, /) -> None: self._channels[channel.id] = channel - def _remove_channel(self, channel): + def _remove_channel(self, channel: Snowflake, /) -> None: self._channels.pop(channel.id, None) - def _voice_state_for(self, user_id): + def _voice_state_for(self, user_id: int, /) -> Optional[VoiceState]: return self._voice_states.get(user_id) - def _add_member(self, member): + def _add_member(self, member: Member, /) -> None: self._members[member.id] = member - def _store_thread(self, payload) -> Thread: + def _store_thread(self, payload: ThreadPayload, /) -> Thread: thread = Thread(guild=self, data=payload) self._threads[thread.id] = thread return thread - def _remove_member(self, member): + def _remove_member(self, member: Snowflake, /) -> None: self._members.pop(member.id, None) - def _add_thread(self, thread): + def _add_thread(self, thread: Thread, /) -> None: self._threads[thread.id] = thread - def _remove_thread(self, thread): + def _remove_thread(self, thread: Snowflake, /) -> None: self._threads.pop(thread.id, None) - def _clear_threads(self): + def _clear_threads(self) -> None: self._threads.clear() - def _remove_threads_by_channel(self, channel_id: int): + def _remove_threads_by_channel(self, channel_id: int) -> None: to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id] for k in to_remove: del self._threads[k] @@ -241,10 +357,10 @@ class Guild(Hashable): del self._threads[k] return to_remove - def __str__(self): + def __str__(self) -> str: return self.name or '' - def __repr__(self): + def __repr__(self) -> str: attrs = ( ('id', self.id), ('name', self.name), @@ -255,7 +371,7 @@ class Guild(Hashable): inner = ' '.join('%s=%r' % t for t in attrs) return f'' - def _update_voice_state(self, data, channel_id): + def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: user_id = int(data['user_id']) channel = self.get_channel(channel_id) try: @@ -282,18 +398,18 @@ class Guild(Hashable): return member, before, after - def _add_role(self, role): + def _add_role(self, role: Role, /) -> None: # roles get added to the bottom (position 1, pos 0 is @everyone) # so since self.roles has the @everyone role, we can't increment # its position because it's stuck at position 0. Luckily x += False # is equivalent to adding 0. So we cast the position to a bool and # increment it. for r in self._roles.values(): - r.position += (not r.is_default()) + r.position += not r.is_default() self._roles[role.id] = role - def _remove_role(self, role_id): + def _remove_role(self, role_id: int, /) -> Role: # this raises KeyError if it fails.. role = self._roles.pop(role_id) @@ -305,7 +421,7 @@ class Guild(Hashable): return role - def _from_data(self, guild): + def _from_data(self, guild: GuildPayload) -> None: # according to Stan, this is always available even if the guild is unavailable # I don't have this guarantee when someone updates the guild. member_count = guild.get('member_count', None) @@ -323,7 +439,7 @@ class Guild(Hashable): self.unavailable = guild.get('unavailable', False) self.id = int(guild['id']) self._roles = {} - state = self._state # speed up attribute access + state = self._state # speed up attribute access for r in guild.get('roles', []): role = Role(guild=self, data=r, state=state) self._roles[role.id] = role @@ -362,12 +478,13 @@ class Guild(Hashable): self._large = None if member_count is None else self._member_count >= 250 self.owner_id = utils._get_as_snowflake(guild, 'owner_id') - self.afk_channel = self.get_channel(utils._get_as_snowflake(guild, 'afk_channel_id')) + self.afk_channel = self.get_channel(utils._get_as_snowflake(guild, 'afk_channel_id')) # type: ignore for obj in guild.get('voice_states', []): self._update_voice_state(obj, int(obj['channel_id'])) - def _sync(self, data): + # TODO: refactor/remove? + def _sync(self, data: GuildPayload) -> None: try: self._large = data['large'] except KeyError: @@ -383,7 +500,7 @@ class Guild(Hashable): if 'channels' in data: channels = data['channels'] for c in channels: - factory, ch_type = _channel_factory(c['type']) + factory, ch_type = _guild_channel_factory(c['type']) if factory: self._add_channel(factory(guild=self, data=c, state=self._state)) @@ -393,12 +510,12 @@ class Guild(Hashable): self._add_thread(Thread(guild=self, data=thread)) @property - def channels(self): + def channels(self) -> List[GuildChannel]: """List[:class:`abc.GuildChannel`]: A list of channels that belongs to this guild.""" return list(self._channels.values()) @property - def threads(self): + def threads(self) -> List[Thread]: """List[:class:`Thread`]: A list of threads that you have permission to view. .. versionadded:: 2.0 @@ -406,7 +523,7 @@ class Guild(Hashable): return list(self._threads.values()) @property - def large(self): + def large(self) -> bool: """:class:`bool`: Indicates if the guild is a 'large' guild. A large guild is defined as having more than ``large_threshold`` count @@ -420,7 +537,7 @@ class Guild(Hashable): return self._large @property - def voice_channels(self): + def voice_channels(self) -> List[VoiceChannel]: """List[:class:`VoiceChannel`]: A list of voice channels that belongs to this guild. This is sorted by the position and are in UI order from top to bottom. @@ -430,7 +547,7 @@ class Guild(Hashable): return r @property - def stage_channels(self): + def stage_channels(self) -> List[StageChannel]: """List[:class:`StageChannel`]: A list of stage channels that belongs to this guild. .. versionadded:: 1.7 @@ -442,20 +559,21 @@ class Guild(Hashable): return r @property - def me(self): + def me(self) -> Member: """:class:`Member`: Similar to :attr:`Client.user` except an instance of :class:`Member`. This is essentially used to get the member version of yourself. """ self_id = self._state.user.id - return self.get_member(self_id) + # The self member is *always* cached + return self.get_member(self_id) # type: ignore @property - def voice_client(self): + def voice_client(self) -> Optional[VoiceProtocol]: """Optional[:class:`VoiceProtocol`]: Returns the :class:`VoiceProtocol` associated with this guild, if any.""" return self._state._get_voice_client(self.id) @property - def text_channels(self): + def text_channels(self) -> List[TextChannel]: """List[:class:`TextChannel`]: A list of text channels that belongs to this guild. This is sorted by the position and are in UI order from top to bottom. @@ -465,7 +583,7 @@ class Guild(Hashable): return r @property - def categories(self): + def categories(self) -> List[CategoryChannel]: """List[:class:`CategoryChannel`]: A list of categories that belongs to this guild. This is sorted by the position and are in UI order from top to bottom. @@ -474,7 +592,7 @@ class Guild(Hashable): r.sort(key=lambda c: (c.position, c.id)) return r - def by_category(self): + def by_category(self) -> List[ByCategoryItem]: """Returns every :class:`CategoryChannel` and their associated channels. These channels and categories are sorted in the official Discord UI order. @@ -487,7 +605,7 @@ class Guild(Hashable): List[Tuple[Optional[:class:`CategoryChannel`], List[:class:`abc.GuildChannel`]]]: The categories and their associated channels. """ - grouped = {} + grouped: Dict[Optional[int], List[GuildChannel]] = {} for channel in self._channels.values(): if isinstance(channel, CategoryChannel): grouped.setdefault(channel.id, []) @@ -498,18 +616,18 @@ class Guild(Hashable): except KeyError: grouped[channel.category_id] = [channel] - def key(t): + def key(t: ByCategoryItem) -> Tuple[Tuple[int, int], List[GuildChannel]]: k, v = t return ((k.position, k.id) if k else (-1, -1), v) _get = self._channels.get - as_list = [(_get(k), v) for k, v in grouped.items()] + as_list: List[ByCategoryItem] = [(_get(k), v) for k, v in grouped.items()] # type: ignore as_list.sort(key=key) for _, channels in as_list: channels.sort(key=lambda c: (c._sorting_bucket, c.position, c.id)) return as_list - def get_channel(self, channel_id): + def get_channel(self, channel_id: int, /) -> Optional[GuildChannel]: """Returns a channel with the given ID. .. note:: @@ -528,7 +646,7 @@ class Guild(Hashable): """ return self._channels.get(channel_id) - def get_thread(self, thread_id): + def get_thread(self, thread_id: int, /) -> Optional[Thread]: """Returns a thread with the given ID. .. versionadded:: 2.0 @@ -546,21 +664,21 @@ class Guild(Hashable): return self._threads.get(thread_id) @property - def system_channel(self): + def system_channel(self) -> Optional[TextChannel]: """Optional[:class:`TextChannel`]: Returns the guild's channel used for system messages. If no channel is set, then this returns ``None``. """ channel_id = self._system_channel_id - return channel_id and self._channels.get(channel_id) + return channel_id and self._channels.get(channel_id) # type: ignore @property - def system_channel_flags(self): + def system_channel_flags(self) -> SystemChannelFlags: """:class:`SystemChannelFlags`: Returns the guild's system channel settings.""" return SystemChannelFlags._from_value(self._system_channel_flags) @property - def rules_channel(self): + def rules_channel(self) -> Optional[TextChannel]: """Optional[:class:`TextChannel`]: Return's the guild's channel used for the rules. The guild must be a Community guild. @@ -569,10 +687,10 @@ class Guild(Hashable): .. versionadded:: 1.3 """ channel_id = self._rules_channel_id - return channel_id and self._channels.get(channel_id) + return channel_id and self._channels.get(channel_id) # type: ignore @property - def public_updates_channel(self): + def public_updates_channel(self) -> Optional[TextChannel]: """Optional[:class:`TextChannel`]: Return's the guild's channel where admins and moderators of the guilds receive notices from Discord. The guild must be a Community guild. @@ -582,31 +700,31 @@ class Guild(Hashable): .. versionadded:: 1.4 """ channel_id = self._public_updates_channel_id - return channel_id and self._channels.get(channel_id) + return channel_id and self._channels.get(channel_id) # type: ignore @property - def emoji_limit(self): + def emoji_limit(self) -> int: """:class:`int`: The maximum number of emoji slots this guild has.""" more_emoji = 200 if 'MORE_EMOJI' in self.features else 50 return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji) @property - def bitrate_limit(self): + def bitrate_limit(self) -> float: """:class:`float`: The maximum bitrate for voice channels this guild can have.""" vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if 'VIP_REGIONS' in self.features else 96e3 return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate) @property - def filesize_limit(self): + def filesize_limit(self) -> int: """:class:`int`: The maximum number of bytes files can have when uploaded to this guild.""" return self._PREMIUM_GUILD_LIMITS[self.premium_tier].filesize @property - def members(self): + def members(self) -> List[Member]: """List[:class:`Member`]: A list of members that belong to this guild.""" return list(self._members.values()) - def get_member(self, user_id): + def get_member(self, user_id: int) -> Optional[Member]: """Returns a member with the given ID. Parameters @@ -622,12 +740,12 @@ class Guild(Hashable): return self._members.get(user_id) @property - def premium_subscribers(self): + def premium_subscribers(self) -> List[Member]: """List[:class:`Member`]: A list of members who have "boosted" this guild.""" return [member for member in self.members if member.premium_since is not None] @property - def roles(self): + def roles(self) -> List[Role]: """List[:class:`Role`]: Returns a :class:`list` of the guild's roles in hierarchy order. The first element of this list will be the lowest role in the @@ -635,7 +753,7 @@ class Guild(Hashable): """ return sorted(self._roles.values()) - def get_role(self, role_id): + def get_role(self, role_id: int, /) -> Optional[Role]: """Returns a role with the given ID. Parameters @@ -651,12 +769,13 @@ class Guild(Hashable): return self._roles.get(role_id) @property - def default_role(self): + def default_role(self) -> Role: """:class:`Role`: Gets the @everyone role that all members have by default.""" - return self.get_role(self.id) + # The @everyone role is *always* given + return self.get_role(self.id) # type: ignore @property - def premium_subscriber_role(self): + def premium_subscriber_role(self) -> Optional[Role]: """Optional[:class:`Role`]: Gets the premium subscriber role, AKA "boost" role, in this guild. .. versionadded:: 1.6 @@ -667,7 +786,7 @@ class Guild(Hashable): return None @property - def self_role(self): + def self_role(self) -> Optional[Role]: """Optional[:class:`Role`]: Gets the role associated with this client's user, if any. .. versionadded:: 1.6 @@ -688,7 +807,7 @@ class Guild(Hashable): """ return list(self._stage_instances.values()) - def get_stage_instance(self, stage_instance_id: int) -> Optional[StageInstance]: + def get_stage_instance(self, stage_instance_id: int, /) -> Optional[StageInstance]: """Returns a stage instance with the given ID. .. versionadded:: 2.0 @@ -706,40 +825,40 @@ class Guild(Hashable): return self._stage_instances.get(stage_instance_id) @property - def owner(self): + def owner(self) -> Optional[Member]: """Optional[:class:`Member`]: The member that owns the guild.""" - return self.get_member(self.owner_id) + return self.get_member(self.owner_id) # type: ignore @property - def icon(self): + def icon(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's icon asset, if available.""" if self._icon is None: return None return Asset._from_guild_icon(self._state, self.id, self._icon) @property - def banner(self): + def banner(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') @property - def splash(self): + def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') @property - def discovery_splash(self): + def discovery_splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available.""" if self._discovery_splash is None: return None return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes') @property - def member_count(self): + def member_count(self) -> int: """:class:`int`: Returns the true member count regardless of it being loaded fully or not. .. warning:: @@ -751,7 +870,7 @@ class Guild(Hashable): return self._member_count @property - def chunked(self): + def chunked(self) -> bool: """:class:`bool`: Returns a boolean indicating if the guild is "chunked". A chunked guild means that :attr:`member_count` is equal to the @@ -766,19 +885,19 @@ class Guild(Hashable): return count == len(self._members) @property - def shard_id(self): + def shard_id(self) -> int: """:class:`int`: Returns the shard ID for this guild if applicable.""" count = self._state.shard_count if count is None: - return None + return 0 return (self.id >> 22) % count @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the guild's creation time in UTC.""" return utils.snowflake_time(self.id) - def get_member_named(self, name): + def get_member_named(self, name: str, /) -> Optional[Member]: """Returns the first member found that matches the name provided. The name can have an optional discriminator argument, e.g. "Jake#0001" @@ -819,13 +938,20 @@ class Guild(Hashable): if result is not None: return result - def pred(m): + def pred(m: Member) -> bool: return m.nick == name or m.name == name return utils.find(pred, members) - def _create_channel(self, name, overwrites, channel_type, category=None, **options): - if overwrites is None: + def _create_channel( + self, + name: str, + channel_type: ChannelType, + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, + category: Optional[Snowflake] = None, + **options: Any, + ): + if overwrites is MISSING: overwrites = {} elif not isinstance(overwrites, dict): raise InvalidArgument('overwrites parameter expects a dict.') @@ -836,11 +962,7 @@ class Guild(Hashable): raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') allow, deny = perm.pair() - payload = { - 'allow': allow.value, - 'deny': deny.value, - 'id': target.id - } + payload = {'allow': allow.value, 'deny': deny.value, 'id': target.id} if isinstance(target, Role): payload['type'] = abc._Overwrites.ROLE @@ -849,45 +971,23 @@ class Guild(Hashable): perms.append(payload) - try: - options['rate_limit_per_user'] = options.pop('slowmode_delay') - except KeyError: - pass - - try: - rtc_region = options.pop('rtc_region') - except KeyError: - pass - else: - options['rtc_region'] = None if rtc_region is None else str(rtc_region) - parent_id = category.id if category else None - return self._state.http.create_channel(self.id, channel_type.value, name=name, parent_id=parent_id, - permission_overwrites=perms, **options) + return self._state.http.create_channel( + self.id, channel_type.value, name=name, parent_id=parent_id, permission_overwrites=perms, **options + ) - @overload async def create_text_channel( self, name: str, *, - reason: Optional[str] = ..., - category: Optional[CategoryChannel], - position: int = ..., - topic: Optional[str] = ..., - slowmode_delay: int = ..., - nsfw: bool = ..., - overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., - ) -> TextChannel: - ... - - @overload - async def create_text_channel( - self, - name: str + reason: Optional[str] = None, + category: Optional[CategoryChannel] = None, + position: int = MISSING, + topic: str = MISSING, + slowmode_delay: int = MISSING, + nsfw: bool = MISSING, + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, ) -> TextChannel: - ... - - async def create_text_channel(self, name, *, overwrites=None, category=None, reason=None, **options): """|coro| Creates a :class:`TextChannel` for the guild. @@ -930,7 +1030,7 @@ class Guild(Hashable): ----------- name: :class:`str` The channel's name. - overwrites + overwrites: Dict[Union[:class:`Role`, :class:`Member`], :class:`PermissionOverwrite`] A :class:`dict` of target (either a role or a member) to :class:`PermissionOverwrite` to apply upon creation of a channel. Useful for creating secret channels. @@ -941,7 +1041,7 @@ class Guild(Hashable): position: :class:`int` The position in the channel list. This is a number that starts at 0. e.g. the top channel is position 0. - topic: Optional[:class:`str`] + topic: :class:`str` The new channel's topic. slowmode_delay: :class:`int` Specifies the slowmode rate limit for user in this channel, in seconds. @@ -965,44 +1065,61 @@ class Guild(Hashable): :class:`TextChannel` The channel that was just created. """ - data = await self._create_channel(name, overwrites, ChannelType.text, category, reason=reason, **options) + + options = {} + if position is not MISSING: + options['position'] = position + + if topic is not MISSING: + options['topic'] = topic + + if slowmode_delay is not MISSING: + options['rate_limit_per_user'] = slowmode_delay + + if nsfw is not MISSING: + options['nsfw'] = nsfw + + data = await self._create_channel( + name, overwrites=overwrites, channel_type=ChannelType.text, category=category, reason=reason, **options + ) channel = TextChannel(state=self._state, guild=self, data=data) # temporarily add to the cache self._channels[channel.id] = channel return channel - @overload async def create_voice_channel( self, name: str, *, - reason: Optional[str] = ..., - category: Optional[CategoryChannel], - position: int = ..., - bitrate: int = ..., - user_limit: int = ..., - rtc_region: Optional[VoiceRegion] = ..., - voice_quality_mode: VideoQualityMode = ..., - overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., - ) -> VoiceChannel: - ... - - @overload - async def create_voice_channel( - self, - name: str + reason: Optional[str] = None, + category: Optional[CategoryChannel] = None, + position: int = MISSING, + bitrate: int = MISSING, + user_limit: int = MISSING, + rtc_region: Optional[VoiceRegion] = MISSING, + video_quality_mode: VideoQualityMode = MISSING, + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, ) -> VoiceChannel: - ... - - async def create_voice_channel(self, name, *, overwrites=None, category=None, reason=None, **options): """|coro| - This is similar to :meth:`create_text_channel` except makes a :class:`VoiceChannel` instead, in addition - to having the following new parameters. + This is similar to :meth:`create_text_channel` except makes a :class:`VoiceChannel` instead. Parameters ----------- + name: :class:`str` + The channel's name. + overwrites: Dict[Union[:class:`Role`, :class:`Member`], :class:`PermissionOverwrite`] + A :class:`dict` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply upon creation of a channel. + Useful for creating secret channels. + category: Optional[:class:`CategoryChannel`] + The category to place the newly created channel under. + The permissions will be automatically synced to category if no + overwrites are provided. + position: :class:`int` + The position in the channel list. This is a number that starts + at 0. e.g. the top channel is position 0. bitrate: :class:`int` The channel's preferred audio bitrate in bits per second. user_limit: :class:`int` @@ -1016,6 +1133,8 @@ class Guild(Hashable): The camera video quality for the voice channel's participants. .. versionadded:: 2.0 + reason: Optional[:class:`str`] + The reason for creating this channel. Shows up on the audit log. Raises ------ @@ -1031,7 +1150,25 @@ class Guild(Hashable): :class:`VoiceChannel` The channel that was just created. """ - data = await self._create_channel(name, overwrites, ChannelType.voice, category, reason=reason, **options) + options = {} + if position is not MISSING: + options['position'] = position + + if bitrate is not MISSING: + options['bitrate'] = bitrate + + if user_limit is not MISSING: + options['user_limit'] = user_limit + + if rtc_region is not MISSING: + options['rtc_region'] = None if rtc_region is None else str(rtc_region) + + if video_quality_mode is not MISSING: + options['video_quality_mode'] = video_quality_mode.value + + data = await self._create_channel( + name, overwrites=overwrites, channel_type=ChannelType.voice, category=category, reason=reason, **options + ) channel = VoiceChannel(state=self._state, guild=self, data=data) # temporarily add to the cache @@ -1042,22 +1179,38 @@ class Guild(Hashable): self, name: str, *, - reason: Optional[str] = ..., - category: Optional[CategoryChannel], topic: str, - position: int = ..., - overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + position: int = MISSING, + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, ) -> StageChannel: """|coro| This is similar to :meth:`create_text_channel` except makes a :class:`StageChannel` instead. - .. note:: - - The ``slowmode_delay`` and ``nsfw`` parameters are not supported in this function. - .. versionadded:: 1.7 + Parameters + ----------- + name: :class:`str` + The channel's name. + topic: :class:`str` + The new channel's topic. + overwrites: Dict[Union[:class:`Role`, :class:`Member`], :class:`PermissionOverwrite`] + A :class:`dict` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply upon creation of a channel. + Useful for creating secret channels. + category: Optional[:class:`CategoryChannel`] + The category to place the newly created channel under. + The permissions will be automatically synced to category if no + overwrites are provided. + position: :class:`int` + The position in the channel list. This is a number that starts + at 0. e.g. the top channel is position 0. + reason: Optional[:class:`str`] + The reason for creating this channel. Shows up on the audit log. + Raises ------ Forbidden @@ -1072,7 +1225,16 @@ class Guild(Hashable): :class:`StageChannel` The channel that was just created. """ - data = await self._create_channel(name, overwrites, ChannelType.stage_voice, category, reason=reason, position=position, topic=topic) + + options: Dict[str, Any] = { + 'topic': topic, + } + if position is not MISSING: + options['position'] = position + + data = await self._create_channel( + name, overwrites=overwrites, channel_type=ChannelType.stage_voice, category=category, reason=reason, **options + ) channel = StageChannel(state=self._state, guild=self, data=data) # temporarily add to the cache @@ -1083,9 +1245,9 @@ class Guild(Hashable): self, name: str, *, - overwrites: Dict[Union[Role, Member], PermissionOverwrite] = None, + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, reason: Optional[str] = None, - position: int = None + position: int = MISSING, ) -> CategoryChannel: """|coro| @@ -1110,7 +1272,13 @@ class Guild(Hashable): :class:`CategoryChannel` The channel that was just created. """ - data = await self._create_channel(name, overwrites, ChannelType.category, reason=reason, position=position) + options: Dict[str, Any] = {} + if position is not MISSING: + options['position'] = position + + data = await self._create_channel( + name, overwrites=overwrites, channel_type=ChannelType.category, reason=reason, **options + ) channel = CategoryChannel(state=self._state, guild=self, data=data) # temporarily add to the cache @@ -1119,7 +1287,7 @@ class Guild(Hashable): create_category_channel = create_category - async def leave(self): + async def leave(self) -> None: """|coro| Leaves the guild. @@ -1136,7 +1304,7 @@ class Guild(Hashable): """ await self._state.http.leave_guild(self.id) - async def delete(self): + async def delete(self) -> None: """|coro| Deletes the guild. You must be the guild owner to delete the @@ -1152,38 +1320,32 @@ class Guild(Hashable): await self._state.http.delete_guild(self.id) - @overload async def edit( self, *, - reason: Optional[str] = ..., - name: str = ..., - description: Optional[str] = ..., - icon: Optional[bytes] = ..., - banner: Optional[bytes] = ..., - splash: Optional[bytes] = ..., - discovery_splash: Optional[bytes] = ..., - community: bool = ..., - region: Optional[VoiceRegion] = ..., - afk_channel: Optional[VoiceChannel] = ..., - afk_timeout: int = ..., - default_notifications: NotificationLevel = ..., - verification_level: VerificationLevel = ..., - explicit_content_filter: ContentFilter = ..., - vanity_code: str = ..., - system_channel: Optional[TextChannel] = ..., - system_channel_flags: SystemChannelFlags = ..., - preferred_locale: str = ..., - rules_channel: Optional[TextChannel] = ..., - public_updates_channel: Optional[TextChannel] = ..., + reason: Optional[str] = MISSING, + name: str = MISSING, + description: Optional[str] = MISSING, + icon: Optional[bytes] = MISSING, + banner: Optional[bytes] = MISSING, + splash: Optional[bytes] = MISSING, + discovery_splash: Optional[bytes] = MISSING, + community: bool = MISSING, + region: Optional[Union[str, VoiceRegion]] = MISSING, + afk_channel: Optional[VoiceChannel] = MISSING, + owner: Snowflake = MISSING, + afk_timeout: int = MISSING, + default_notifications: NotificationLevel = MISSING, + verification_level: VerificationLevel = MISSING, + explicit_content_filter: ContentFilter = MISSING, + vanity_code: str = MISSING, + system_channel: Optional[TextChannel] = MISSING, + system_channel_flags: SystemChannelFlags = MISSING, + preferred_locale: str = MISSING, + rules_channel: Optional[TextChannel] = MISSING, + public_updates_channel: Optional[TextChannel] = MISSING, ) -> None: ... - - @overload - async def edit(self) -> None: - ... - - async def edit(self, *, reason=None, **fields): r"""|coro| Edits the guild. @@ -1225,7 +1387,7 @@ class Guild(Hashable): community: :class:`bool` Whether the guild should be a Community guild. If set to ``True``\, both ``rules_channel`` and ``public_updates_channel`` parameters are required. - region: :class:`VoiceRegion` + region: Union[:class:`str`, :class:`VoiceRegion`] The new region for the guild's voice communication. afk_channel: Optional[:class:`VoiceChannel`] The new channel that is the AFK channel. Could be ``None`` for no AFK channel. @@ -1273,146 +1435,118 @@ class Guild(Hashable): """ http = self._state.http - try: - icon_bytes = fields['icon'] - except KeyError: - icon = self._icon - else: - if icon_bytes is not None: - icon = utils._bytes_to_base64_data(icon_bytes) - else: - icon = None - - try: - banner_bytes = fields['banner'] - except KeyError: - banner = self._banner - else: - if banner_bytes is not None: - banner = utils._bytes_to_base64_data(banner_bytes) - else: - banner = None - try: - vanity_code = fields['vanity_code'] - except KeyError: - pass - else: + if vanity_code is not MISSING: await http.change_vanity_code(self.id, vanity_code, reason=reason) - try: - splash_bytes = fields['splash'] - except KeyError: - splash = self._splash - else: - if splash_bytes is not None: - splash = utils._bytes_to_base64_data(splash_bytes) + fields: Dict[str, Any] = {} + if name is not MISSING: + fields['name'] = name + + if description is not MISSING: + fields['description'] = description + + if preferred_locale is not MISSING: + fields['preferred_locale'] = preferred_locale + + if afk_timeout is not MISSING: + fields['afk_timeout'] = afk_timeout + + if icon is not MISSING: + if icon is None: + fields['icon'] = icon else: - splash = None + fields['icon'] = utils._bytes_to_base64_data(icon) - try: - discovery_splash_bytes = fields['discovery_splash'] - except KeyError: - pass - else: - if discovery_splash_bytes is not None: - fields['discovery_splash'] = utils._bytes_to_base64_data(discovery_splash_bytes) + if banner is not MISSING: + if banner is None: + fields['banner'] = banner else: - fields['discovery_splash'] = None + fields['banner'] = utils._bytes_to_base64_data(banner) - fields['icon'] = icon - fields['banner'] = banner - fields['splash'] = splash + if splash is not MISSING: + if splash is None: + fields['splash'] = splash + else: + fields['splash'] = utils._bytes_to_base64_data(splash) - default_message_notifications = fields.get('default_notifications', self.default_notifications) - if not isinstance(default_message_notifications, NotificationLevel): - raise InvalidArgument('default_notifications field must be of type NotificationLevel') - fields['default_message_notifications'] = default_message_notifications.value + if discovery_splash is not MISSING: + if discovery_splash is None: + fields['discovery_splash'] = discovery_splash + else: + fields['discovery_splash'] = utils._bytes_to_base64_data(discovery_splash) - try: - afk_channel = fields.pop('afk_channel') - except KeyError: - pass - else: + if default_notifications is not MISSING: + if not isinstance(default_notifications, NotificationLevel): + raise InvalidArgument('default_notifications field must be of type NotificationLevel') + fields['default_message_notifications'] = default_notifications.value + + if afk_channel is not MISSING: if afk_channel is None: fields['afk_channel_id'] = afk_channel else: fields['afk_channel_id'] = afk_channel.id - try: - system_channel = fields.pop('system_channel') - except KeyError: - pass - else: + if system_channel is not MISSING: if system_channel is None: fields['system_channel_id'] = system_channel else: fields['system_channel_id'] = system_channel.id - if 'owner' in fields: - if self.owner_id != self._state.self_id: - raise InvalidArgument('To transfer ownership you must be the owner of the guild.') + if rules_channel is not MISSING: + if rules_channel is None: + fields['rules_channel_id'] = rules_channel + else: + fields['rules_channel_id'] = rules_channel.id - fields['owner_id'] = fields['owner'].id + if public_updates_channel is not MISSING: + if public_updates_channel is None: + fields['public_updates_channel_id'] = public_updates_channel + else: + fields['public_updates_channel_id'] = public_updates_channel.id - if 'region' in fields: - fields['region'] = str(fields['region']) + if owner is not MISSING: + if self.owner_id != self._state.self_id: + raise InvalidArgument('To transfer ownership you must be the owner of the guild.') - level = fields.get('verification_level', self.verification_level) - if not isinstance(level, VerificationLevel): - raise InvalidArgument('verification_level field must be of type VerificationLevel') + fields['owner_id'] = owner.id - fields['verification_level'] = level.value + if region is not MISSING: + fields['region'] = str(region) - explicit_content_filter = fields.get('explicit_content_filter', self.explicit_content_filter) - if not isinstance(explicit_content_filter, ContentFilter): - raise InvalidArgument('explicit_content_filter field must be of type ContentFilter') + if verification_level is not MISSING: + if not isinstance(verification_level, VerificationLevel): + raise InvalidArgument('verification_level field must be of type VerificationLevel') - fields['explicit_content_filter'] = explicit_content_filter.value + fields['verification_level'] = verification_level.value - system_channel_flags = fields.get('system_channel_flags', self.system_channel_flags) - if not isinstance(system_channel_flags, SystemChannelFlags): - raise InvalidArgument('system_channel_flags field must be of type SystemChannelFlags') + if explicit_content_filter is not MISSING: + if not isinstance(explicit_content_filter, ContentFilter): + raise InvalidArgument('explicit_content_filter field must be of type ContentFilter') - fields['system_channel_flags'] = system_channel_flags.value + fields['explicit_content_filter'] = explicit_content_filter.value - try: - rules_channel = fields.pop('rules_channel') - except KeyError: - pass - else: - if rules_channel is None: - fields['rules_channel_id'] = rules_channel - else: - fields['rules_channel_id'] = rules_channel.id + if system_channel_flags is not MISSING: + if not isinstance(system_channel_flags, SystemChannelFlags): + raise InvalidArgument('system_channel_flags field must be of type SystemChannelFlags') - try: - public_updates_channel = fields.pop('public_updates_channel') - except KeyError: - pass - else: - if public_updates_channel is None: - fields['public_updates_channel_id'] = public_updates_channel - else: - fields['public_updates_channel_id'] = public_updates_channel.id + fields['system_channel_flags'] = system_channel_flags.value - try: - community = fields.pop('community') - except KeyError: - pass - else: + if community is not MISSING: features = [] if community: if 'rules_channel_id' in fields and 'public_updates_channel_id' in fields: features.append('COMMUNITY') else: - raise InvalidArgument('community field requires both rules_channel and public_updates_channel fields to be provided') + raise InvalidArgument( + 'community field requires both rules_channel and public_updates_channel fields to be provided' + ) fields['features'] = features await http.edit_guild(self.id, reason=reason, **fields) - async def fetch_channels(self): + async def fetch_channels(self) -> Sequence[GuildChannel]: """|coro| Retrieves all :class:`abc.GuildChannel` that the guild has. @@ -1432,13 +1566,13 @@ class Guild(Hashable): Returns ------- - List[:class:`abc.GuildChannel`] + Sequence[:class:`abc.GuildChannel`] All channels in the guild. """ data = await self._state.http.get_all_guild_channels(self.id) def convert(d): - factory, ch_type = _channel_factory(d['type']) + factory, ch_type = _guild_channel_factory(d['type']) if factory is None: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(d)) @@ -1447,7 +1581,8 @@ class Guild(Hashable): return [convert(d) for d in data] - def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> List[Member]: + # TODO: Remove Optional typing here when async iterators are refactored + def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this, :meth:`Intents.members` must be enabled. @@ -1500,7 +1635,7 @@ class Guild(Hashable): return MemberIterator(self, limit=limit, after=after) - async def fetch_member(self, member_id): + async def fetch_member(self, member_id: int, /) -> Member: """|coro| Retrieves a :class:`Member` from a guild ID, and a member ID. @@ -1529,7 +1664,7 @@ class Guild(Hashable): data = await self._state.http.get_member(self.id, member_id) return Member(data=data, state=self._state, guild=self) - async def fetch_ban(self, user): + async def fetch_ban(self, user: Snowflake) -> BanEntry: """|coro| Retrieves the :class:`BanEntry` for a user. @@ -1557,12 +1692,9 @@ class Guild(Hashable): The :class:`BanEntry` object for the specified user. """ data: BanPayload = await self._state.http.get_ban(user.id, self.id) - return BanEntry( - user=User(state=self._state, data=data['user']), - reason=data['reason'] - ) + return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason']) - async def fetch_channel(self, channel_id: int, /) -> abc.GuildChannel: + async def fetch_channel(self, channel_id: int, /) -> GuildChannel: """|coro| Retrieves a :class:`.abc.GuildChannel` with the specified ID. @@ -1593,7 +1725,7 @@ class Guild(Hashable): """ data = await self._state.http.get_channel(channel_id) - factory, ch_type = _channel_factory(data['type']) + factory, ch_type = _guild_channel_factory(data['type']) if factory is None: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) @@ -1604,10 +1736,10 @@ class Guild(Hashable): if self.id != guild_id: raise InvalidData('Guild ID resolved to a different guild') - channel: abc.GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore + channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore return channel - async def bans(self): + async def bans(self) -> List[BanEntry]: """|coro| Retrieves all the users that are banned from the guild as a :class:`list` of :class:`BanEntry`. @@ -1629,17 +1761,15 @@ class Guild(Hashable): """ data: List[BanPayload] = await self._state.http.get_bans(self.id) - return [BanEntry(user=User(state=self._state, data=e['user']), - reason=e['reason']) - for e in data] + return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data] async def prune_members( self, *, days: int, compute_prune_count: bool = True, - roles: Optional[List[abc.Snowflake]] = None, - reason: Optional[str] = None + roles: List[Snowflake] = MISSING, + reason: Optional[str] = None, ) -> Optional[int]: r"""|coro| @@ -1670,7 +1800,7 @@ class Guild(Hashable): which makes it prone to timeouts in very large guilds. In order to prevent timeouts, you must set this to ``False``. If this is set to ``False``\, then this function will always return ``None``. - roles: Optional[List[:class:`abc.Snowflake`]] + roles: List[:class:`abc.Snowflake`] A list of :class:`abc.Snowflake` that represent roles to include in the pruning process. If a member has a role that is not specified, they'll be excluded. @@ -1694,12 +1824,16 @@ class Guild(Hashable): raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') if roles: - roles = [str(role.id) for role in roles] + role_ids = [str(role.id) for role in roles] + else: + role_ids = [] - data = await self._state.http.prune_members(self.id, days, compute_prune_count=compute_prune_count, roles=roles, reason=reason) + data = await self._state.http.prune_members( + self.id, days, compute_prune_count=compute_prune_count, roles=role_ids, reason=reason + ) return data['pruned'] - async def templates(self): + async def templates(self) -> List[Template]: """|coro| Gets the list of templates from this guild. @@ -1719,10 +1853,11 @@ class Guild(Hashable): The templates for this guild. """ from .template import Template + data = await self._state.http.guild_templates(self.id) return [Template(data=d, state=self._state) for d in data] - async def webhooks(self): + async def webhooks(self) -> List[Webhook]: """|coro| Gets the list of webhooks from this guild. @@ -1741,10 +1876,11 @@ class Guild(Hashable): """ from .webhook import Webhook + data = await self._state.http.guild_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def estimate_pruned_members(self, *, days: int, roles: Optional[List[abc.Snowflake]] = None): + async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> int: """|coro| Similar to :meth:`prune_members` except instead of actually @@ -1755,7 +1891,7 @@ class Guild(Hashable): ----------- days: :class:`int` The number of days before counting as inactive. - roles: Optional[List[:class:`abc.Snowflake`]] + roles: List[:class:`abc.Snowflake`] A list of :class:`abc.Snowflake` that represent roles to include in the estimate. If a member has a role that is not specified, they'll be excluded. @@ -1780,9 +1916,11 @@ class Guild(Hashable): raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') if roles: - roles = [str(role.id) for role in roles] + role_ids = [str(role.id) for role in roles] + else: + role_ids = [] - data = await self._state.http.estimate_pruned_members(self.id, days, roles) + data = await self._state.http.estimate_pruned_members(self.id, days, role_ids) return data['pruned'] async def invites(self) -> List[Invite]: @@ -1814,7 +1952,7 @@ class Guild(Hashable): return result - async def create_template(self, *, name: str, description: Optional[str] = None) -> Template: + async def create_template(self, *, name: str, description: str = MISSING) -> Template: """|coro| Creates a template for the guild. @@ -1828,14 +1966,12 @@ class Guild(Hashable): ----------- name: :class:`str` The name of the template. - description: Optional[:class:`str`] + description: :class:`str` The description of the template. """ from .template import Template - payload = { - 'name': name - } + payload = {'name': name} if description: payload['description'] = description @@ -1902,7 +2038,7 @@ class Guild(Hashable): return [convert(d) for d in data] - async def fetch_emojis(self): + async def fetch_emojis(self) -> List[Emoji]: r"""|coro| Retrieves all custom :class:`Emoji`\s from the guild. @@ -1924,7 +2060,7 @@ class Guild(Hashable): data = await self._state.http.get_all_custom_emojis(self.id) return [Emoji(guild=self, state=self._state, data=d) for d in data] - async def fetch_emoji(self, emoji_id): + async def fetch_emoji(self, emoji_id: int, /) -> Emoji: """|coro| Retrieves a custom :class:`Emoji` from the guild. @@ -1959,7 +2095,7 @@ class Guild(Hashable): *, name: str, image: bytes, - roles: Optional[List[Role]] = None, + roles: List[Role] = MISSING, reason: Optional[str] = None, ) -> Emoji: r"""|coro| @@ -1979,7 +2115,7 @@ class Guild(Hashable): image: :class:`bytes` The :term:`py:bytes-like object` representing the image data to use. Only JPG, PNG and GIF images are supported. - roles: Optional[List[:class:`Role`]] + roles: List[:class:`Role`] A :class:`list` of :class:`Role`\s that can use this emoji. Leave empty to make it available to everyone. reason: Optional[:class:`str`] The reason for creating this emoji. Shows up on the audit log. @@ -1999,11 +2135,14 @@ class Guild(Hashable): img = utils._bytes_to_base64_data(image) if roles: - roles = [role.id for role in roles] - data = await self._state.http.create_custom_emoji(self.id, name, img, roles=roles, reason=reason) + role_ids = [role.id for role in roles] + else: + role_ids = [] + + data = await self._state.http.create_custom_emoji(self.id, name, img, roles=role_ids, reason=reason) return self._state.store_emoji(self, data) - async def delete_emoji(self, emoji: abc.Snowflake, *, reason: Optional[str] = None) -> None: + async def delete_emoji(self, emoji: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Deletes the custom :class:`Emoji` from the guild. @@ -2028,7 +2167,7 @@ class Guild(Hashable): await self._state.http.delete_custom_emoji(self.id, emoji.id, reason=reason) - async def fetch_roles(self): + async def fetch_roles(self) -> List[Role]: """|coro| Retrieves all :class:`Role` that the guild has. @@ -2078,7 +2217,18 @@ class Guild(Hashable): ) -> Role: ... - async def create_role(self, *, reason=None, **fields): + async def create_role( + self, + *, + name: str = MISSING, + permissions: Permissions = MISSING, + color: Union[Colour, int] = MISSING, + colour: Union[Colour, int] = MISSING, + hoist: bool = MISSING, + mentionable: str = MISSING, + reason: Optional[str] = None, + ) -> Role: + ... """|coro| Creates a :class:`Role` for the guild. @@ -2123,27 +2273,26 @@ class Guild(Hashable): :class:`Role` The newly created role. """ - - try: - perms = fields.pop('permissions') - except KeyError: + fields: Dict[str, Any] = {} + if permissions is not MISSING: + fields['permissions'] = str(permissions.value) + else: fields['permissions'] = '0' + + actual_colour = colour or color or Colour.default() + if isinstance(actual_colour, int): + fields['color'] = actual_colour else: - fields['permissions'] = str(perms.value) + fields['color'] = actual_colour.value - try: - colour = fields.pop('colour') - except KeyError: - colour = fields.get('color', Colour.default()) - finally: - if isinstance(colour, int): - colour = Colour(value=colour) - fields['color'] = colour.value + if hoist is not MISSING: + fields['hoist'] = hoist + + if mentionable is not MISSING: + fields['mentionable'] = mentionable - valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable') - for key in fields: - if key not in valid_keys: - raise InvalidArgument(f'{key!r} is not a valid field.') + if name is not MISSING: + fields['name'] = name data = await self._state.http.create_role(self.id, reason=reason, **fields) role = Role(guild=self, data=data, state=self._state) @@ -2151,7 +2300,7 @@ class Guild(Hashable): # TODO: add to cache return role - async def edit_role_positions(self, positions: Dict[abc.Snowflake, int], *, reason: Optional[str] = None) -> List[Role]: + async def edit_role_positions(self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]: """|coro| Bulk edits a list of :class:`Role` in the guild. @@ -2198,18 +2347,15 @@ class Guild(Hashable): if not isinstance(positions, dict): raise InvalidArgument('positions parameter expects a dict.') - role_positions = [] + role_positions: List[Dict[str, Any]] = [] for role, position in positions.items(): - payload = { - 'id': role.id, - 'position': position - } + payload = {'id': role.id, 'position': position} role_positions.append(payload) data = await self._state.http.move_role_position(self.id, role_positions, reason=reason) - roles = [] + roles: List[Role] = [] for d in data: role = Role(guild=self, data=d, state=self._state) roles.append(role) @@ -2217,7 +2363,7 @@ class Guild(Hashable): return roles - async def kick(self, user: abc.Snowflake, *, reason: Optional[str] = None) -> None: + async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Kicks a user from the guild. @@ -2245,10 +2391,10 @@ class Guild(Hashable): async def ban( self, - user: abc.Snowflake, + user: Snowflake, *, reason: Optional[str] = None, - delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1 + delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1, ) -> None: """|coro| @@ -2278,7 +2424,7 @@ class Guild(Hashable): """ await self._state.http.ban(user.id, self.id, delete_message_days, reason=reason) - async def unban(self, user: abc.Snowflake, *, reason: Optional[str] = None) -> None: + async def unban(self, user: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Unbans a user from the guild. @@ -2341,6 +2487,7 @@ class Guild(Hashable): payload['max_age'] = 0 return Invite(state=self._state, data=payload, guild=self, channel=channel) + # TODO: use MISSING when async iterators get refactored def audit_logs( self, *, @@ -2348,8 +2495,8 @@ class Guild(Hashable): before: Optional[SnowflakeTime] = None, after: Optional[SnowflakeTime] = None, oldest_first: Optional[bool] = None, - user: abc.Snowflake = None, - action: AuditLogAction = None + user: Snowflake = None, + action: AuditLogAction = None, ) -> AuditLogIterator: """Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs. @@ -2405,16 +2552,19 @@ class Guild(Hashable): :class:`AuditLogEntry` The audit log entry. """ - if user: - user = user.id + if user is not None: + user_id = user.id + else: + user_id = None if action: action = action.value - return AuditLogIterator(self, before=before, after=after, limit=limit, - oldest_first=oldest_first, user_id=user, action_type=action) + return AuditLogIterator( + self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action + ) - async def widget(self): + async def widget(self) -> Widget: """|coro| Returns the widget of the guild. @@ -2439,7 +2589,7 @@ class Guild(Hashable): return Widget(state=self._state, data=data) - async def edit_widget(self, *, enabled: bool = utils.MISSING, channel: Optional[abc.Snowflake] = utils.MISSING) -> None: + async def edit_widget(self, *, enabled: bool = MISSING, channel: Optional[Snowflake] = MISSING) -> None: """|coro| Edits the widget of the guild. @@ -2464,9 +2614,9 @@ class Guild(Hashable): Editing the widget failed. """ payload = {} - if channel is not utils.MISSING: + if channel is not MISSING: payload['channel_id'] = None if channel is None else channel.id - if enabled is not utils.MISSING: + if enabled is not MISSING: payload['enabled'] = enabled await self._state.http.edit_widget(self.id, payload=payload) @@ -2505,7 +2655,7 @@ class Guild(Hashable): limit: int = 5, user_ids: Optional[List[int]] = None, presences: bool = False, - cache: bool = True + cache: bool = True, ) -> List[Member]: """|coro| @@ -2570,9 +2720,13 @@ class Guild(Hashable): raise ValueError('user_ids must contain at least 1 value') limit = min(100, limit or 5) - return await self._state.query_members(self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache) + return await self._state.query_members( + self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache + ) - async def change_voice_state(self, *, channel: Optional[VocalGuildChannel], self_mute: bool = False, self_deaf: bool = False): + async def change_voice_state( + self, *, channel: Optional[VocalGuildChannel], self_mute: bool = False, self_deaf: bool = False + ): """|coro| Changes client's voice state in the guild. diff --git a/discord/state.py b/discord/state.py index 94a1a62a5..243f2b7bd 100644 --- a/discord/state.py +++ b/discord/state.py @@ -44,6 +44,7 @@ from .mentions import AllowedMentions from .partial_emoji import PartialEmoji from .message import Message from .channel import * +from .channel import _channel_factory from .raw_models import * from .member import Member from .role import Role