diff --git a/discord/channel.py b/discord/channel.py index 729f7f4da..f07ba6381 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -26,7 +26,7 @@ from __future__ import annotations import time import asyncio -from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union, overload +from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload import datetime import discord.abc @@ -34,8 +34,9 @@ from .permissions import PermissionOverwrite, Permissions from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode from .mixins import Hashable from . import utils +from .utils import MISSING from .asset import Asset -from .errors import ClientException, NoMoreItems, InvalidArgument +from .errors import ClientException, InvalidArgument from .stage_instance import StageInstance from .threads import Thread from .iterators import ArchivedThreadIterator @@ -55,13 +56,27 @@ if TYPE_CHECKING: from .role import Role from .member import Member, VoiceState from .abc import Snowflake, SnowflakeTime - from .message import Message + from .message import Message, PartialMessage from .webhook import Webhook - -async def _single_delete_strategy(messages): + from .state import ConnectionState + from .user import ClientUser, User, BaseUser + from .guild import Guild, GuildChannel as GuildChannelType + from .types.channel import ( + TextChannel as TextChannelPayload, + VoiceChannel as VoiceChannelPayload, + StageChannel as StageChannelPayload, + DMChannel as DMChannelPayload, + CategoryChannel as CategoryChannelPayload, + StoreChannel as StoreChannelPayload, + GroupDMChannel as GroupChannelPayload, + ) + + +async def _single_delete_strategy(messages: Iterable[Message]): for m in messages: await m.delete() + class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. @@ -114,56 +129,67 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. """ - __slots__ = ('name', 'id', 'guild', 'topic', '_state', 'nsfw', - 'category_id', 'position', 'slowmode_delay', '_overwrites', - '_type', 'last_message_id') - - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) - self._type = data['type'] + __slots__ = ( + 'name', + 'id', + 'guild', + 'topic', + '_state', + 'nsfw', + 'category_id', + 'position', + 'slowmode_delay', + '_overwrites', + '_type', + 'last_message_id', + ) + + def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self._type: int = data['type'] self._update(guild, data) - def __repr__(self): + def __repr__(self) -> str: attrs = [ ('id', self.id), ('name', self.name), ('position', self.position), ('nsfw', self.nsfw), ('news', self.is_news()), - ('category_id', self.category_id) + ('category_id', self.category_id), ] joined = ' '.join('%s=%r' % t for t in attrs) return f'<{self.__class__.__name__} {joined}>' - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.topic = data.get('topic') - self.position = data['position'] - self.nsfw = data.get('nsfw', False) + def _update(self, guild: Guild, data: TextChannelPayload) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.topic: Optional[str] = data.get('topic') + self.position: int = data['position'] + self.nsfw: bool = data.get('nsfw', False) # Does this need coercion into `int`? No idea yet. - self.slowmode_delay = data.get('rate_limit_per_user', 0) - self._type = data.get('type', self._type) - self.last_message_id = utils._get_as_snowflake(data, 'last_message_id') + self.slowmode_delay: int = data.get('rate_limit_per_user', 0) + self._type: int = data.get('type', self._type) + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self._fill_overwrites(data) async def _get_channel(self): return self @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return try_enum(ChannelType, self._type) @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: return ChannelType.text.value @utils.copy_doc(discord.abc.GuildChannel.permissions_for) - def permissions_for(self, member): - base = super().permissions_for(member) + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + base = super().permissions_for(obj) # text channels do not have voice related permissions denied = Permissions.voice() @@ -171,28 +197,28 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return base @property - def members(self): + def members(self) -> List[Member]: """List[:class:`Member`]: Returns all members that can see this channel.""" return [m for m in self.guild.members if self.permissions_for(m).read_messages] @property - def threads(self): + def threads(self) -> List[Thread]: """List[:class:`Thread`]: Returns all the threads that you can see. .. versionadded:: 2.0 """ return [thread for thread in self.guild.threads if thread.parent_id == self.id] - def is_nsfw(self): + def is_nsfw(self) -> bool: """:class:`bool`: Checks if the channel is NSFW.""" return self.nsfw - def is_news(self): + def is_news(self) -> bool: """:class:`bool`: Checks if the channel is a news channel.""" return self._type == ChannelType.news.value @property - def last_message(self): + def last_message(self) -> Optional[Message]: """Fetches the last message from this channel in cache. The message might not be valid or point to an existing message. @@ -289,14 +315,12 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): await self._edit(options, reason=reason) @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: str = None, reason: str = None) -> TextChannel: - return await self._clone_impl({ - 'topic': self.topic, - 'nsfw': self.nsfw, - 'rate_limit_per_user': self.slowmode_delay - }, name=name, reason=reason) - - async def delete_messages(self, messages): + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel: + return await self._clone_impl( + {'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason + ) + + async def delete_messages(self, messages: Iterable[Snowflake]) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -332,24 +356,24 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): messages = list(messages) if len(messages) == 0: - return # do nothing + return # do nothing if len(messages) == 1: - message_id = messages[0].id + message_id: int = messages[0].id await self._state.http.delete_message(self.id, message_id) return if len(messages) > 100: raise ClientException('Can only bulk delete messages up to 100 messages') - message_ids = [m.id for m in messages] + message_ids: List[int] = [m.id for m in messages] await self._state.http.delete_messages(self.id, message_ids) async def purge( self, *, limit: int = 100, - check: Callable[[Message], bool] = None, + check: Callable[[Message], bool] = MISSING, before: Optional[SnowflakeTime] = None, after: Optional[SnowflakeTime] = None, around: Optional[SnowflakeTime] = None, @@ -412,54 +436,52 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): The list of messages that were deleted. """ - if check is None: + if check is MISSING: check = lambda m: True iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) - ret = [] + ret: List[Message] = [] count = 0 minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 strategy = self.delete_messages if bulk else _single_delete_strategy - while True: - try: - msg = await iterator.next() - except NoMoreItems: - # no more messages to poll - if count >= 2: - # more than 2 messages -> bulk delete + async for message in iterator: + if count == 100: + to_delete = ret[-100:] + await strategy(to_delete) + count = 0 + await asyncio.sleep(1) + + if not check(message): + continue + + if message.id < minimum_time: + # older than 14 days old + if count == 1: + await ret[-1].delete() + elif count >= 2: to_delete = ret[-count:] await strategy(to_delete) - elif count == 1: - # delete a single message - await ret[-1].delete() - return ret - else: - if count == 100: - # we've reached a full 'queue' - to_delete = ret[-100:] - await strategy(to_delete) - count = 0 - await asyncio.sleep(1) + count = 0 + strategy = _single_delete_strategy - if check(msg): - if msg.id < minimum_time: - # older than 14 days old - if count == 1: - await ret[-1].delete() - elif count >= 2: - to_delete = ret[-count:] - await strategy(to_delete) + count += 1 + ret.append(message) - count = 0 - strategy = _single_delete_strategy + # SOme messages remaining to poll + if count >= 2: + # more than 2 messages -> bulk delete + to_delete = ret[-count:] + await strategy(to_delete) + elif count == 1: + # delete a single message + await ret[-1].delete() - count += 1 - ret.append(msg) + return ret - async def webhooks(self): + async def webhooks(self) -> List[Webhook]: """|coro| Gets the list of webhooks from this channel. @@ -478,10 +500,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ from .webhook import Webhook + data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook(self, *, name: str, avatar: bytes = None, reason: str = None) -> Webhook: + async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: """|coro| Creates a webhook for this channel. @@ -515,8 +538,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ from .webhook import Webhook + if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) + avatar = utils._bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -563,10 +587,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}') from .webhook import Webhook + data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) return Webhook._as_follower(data, channel=destination, user=self._state.user) - def get_partial_message(self, message_id): + def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. This is useful if you want to work with a message and only have its ID without @@ -586,9 +611,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ from .message import PartialMessage + return PartialMessage(channel=self, id=message_id) - def get_thread(self, thread_id: int) -> Optional[Thread]: + def get_thread(self, thread_id: int, /) -> Optional[Thread]: """Returns a thread with the given ID. .. versionadded:: 2.0 @@ -724,37 +750,47 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): # TODO: thread members? return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])] -class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): - __slots__ = ('name', 'id', 'guild', 'bitrate', 'user_limit', - '_state', 'position', '_overwrites', 'category_id', - 'rtc_region', 'video_quality_mode') - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) +class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): + __slots__ = ( + 'name', + 'id', + 'guild', + 'bitrate', + 'user_limit', + '_state', + 'position', + '_overwrites', + 'category_id', + 'rtc_region', + 'video_quality_mode', + ) + + def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]): + self._state: ConnectionState = state + self.id: int = int(data['id']) self._update(guild, data) - def _get_voice_client_key(self): + def _get_voice_client_key(self) -> Tuple[int, str]: return self.guild.id, 'guild_id' - def _get_voice_state_pair(self): + def _get_voice_state_pair(self) -> Tuple[int, int]: return self.guild.id, self.id - def _update(self, guild, data): + def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: self.guild = guild - self.name = data['name'] - self.rtc_region = data.get('rtc_region') - if self.rtc_region: - self.rtc_region = try_enum(VoiceRegion, self.rtc_region) - self.video_quality_mode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.position = data['position'] - self.bitrate = data.get('bitrate') - self.user_limit = data.get('user_limit') + self.name: str = data['name'] + rtc = data.get('rtc_region') + self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.position: int = data['position'] + self.bitrate: int = data.get('bitrate') + self.user_limit: int = data.get('user_limit') self._fill_overwrites(data) @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: return ChannelType.voice.value @property @@ -787,8 +823,8 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha return {key: value for key, value in self.guild._voice_states.items() if value.channel.id == self.id} @utils.copy_doc(discord.abc.GuildChannel.permissions_for) - def permissions_for(self, member: Union[Role, Member], /) -> Permissions: - base = super().permissions_for(member) + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + base = super().permissions_for(obj) # voice channels cannot be edited by people who can't connect to them # It also implicitly denies all other voice perms @@ -798,6 +834,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha base.value &= ~denied.value return base + class VoiceChannel(VocalGuildChannel): """Represents a Discord guild voice channel. @@ -849,7 +886,7 @@ class VoiceChannel(VocalGuildChannel): __slots__ = () - def __repr__(self): + def __repr__(self) -> str: attrs = [ ('id', self.id), ('name', self.name), @@ -858,28 +895,24 @@ class VoiceChannel(VocalGuildChannel): ('bitrate', self.bitrate), ('video_quality_mode', self.video_quality_mode), ('user_limit', self.user_limit), - ('category_id', self.category_id) + ('category_id', self.category_id), ] joined = ' '.join('%s=%r' % t for t in attrs) return f'<{self.__class__.__name__} {joined}>' @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: str = None, reason: str = None) -> VoiceChannel: - return await self._clone_impl({ - 'bitrate': self.bitrate, - 'user_limit': self.user_limit - }, name=name, reason=reason) + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: + return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason) @overload async def edit( self, *, - reason: Optional[str] = ..., name: str = ..., bitrate: int = ..., user_limit: int = ..., @@ -889,6 +922,7 @@ class VoiceChannel(VocalGuildChannel): overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., rtc_region: Optional[VoiceRegion] = ..., video_quality_mode: VideoQualityMode = ..., + reason: Optional[str] = ..., ) -> None: ... @@ -950,6 +984,7 @@ class VoiceChannel(VocalGuildChannel): await self._edit(options, reason=reason) + class StageChannel(VocalGuildChannel): """Represents a Discord guild stage channel. @@ -1000,9 +1035,10 @@ class StageChannel(VocalGuildChannel): .. versionadded:: 2.0 """ + __slots__ = ('topic',) - def __repr__(self): + def __repr__(self) -> str: attrs = [ ('id', self.id), ('name', self.name), @@ -1012,12 +1048,12 @@ class StageChannel(VocalGuildChannel): ('bitrate', self.bitrate), ('video_quality_mode', self.video_quality_mode), ('user_limit', self.user_limit), - ('category_id', self.category_id) + ('category_id', self.category_id), ] joined = ' '.join('%s=%r' % t for t in attrs) return f'<{self.__class__.__name__} {joined}>' - def _update(self, guild, data): + def _update(self, guild: Guild, data: StageChannelPayload) -> None: super()._update(guild, data) self.topic = data.get('topic') @@ -1032,7 +1068,9 @@ class StageChannel(VocalGuildChannel): .. versionadded:: 2.0 """ - return [member for member in self.members if not member.voice.suppress and member.voice.requested_to_speak_at is None] + return [ + member for member in self.members if not member.voice.suppress and member.voice.requested_to_speak_at is None + ] @property def listeners(self) -> List[Member]: @@ -1052,12 +1090,12 @@ class StageChannel(VocalGuildChannel): return [member for member in self.members if self.permissions_for(member) >= required_permissions] @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.stage_voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StageChannel: + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel: return await self._clone_impl({}, name=name, reason=reason) @property @@ -1068,7 +1106,7 @@ class StageChannel(VocalGuildChannel): """ return utils.get(self.guild.stage_instances, channel_id=self.id) - async def create_instance(self, *, topic: str, privacy_level: StagePrivacyLevel = utils.MISSING) -> StageInstance: + async def create_instance(self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING) -> StageInstance: """|coro| Create a stage instance. @@ -1100,12 +1138,9 @@ class StageChannel(VocalGuildChannel): The newly created stage instance. """ - payload = { - 'channel_id': self.id, - 'topic': topic - } + payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic} - if privacy_level is not utils.MISSING: + if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): raise InvalidArgument('privacy_level field must be of type PrivacyLevel') @@ -1140,7 +1175,6 @@ class StageChannel(VocalGuildChannel): async def edit( self, *, - reason: Optional[str] = ..., name: str = ..., topic: Optional[str] = ..., position: int = ..., @@ -1149,6 +1183,7 @@ class StageChannel(VocalGuildChannel): overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., rtc_region: Optional[VoiceRegion] = ..., video_quality_mode: VideoQualityMode = ..., + reason: Optional[str] = ..., ) -> None: ... @@ -1203,6 +1238,8 @@ class StageChannel(VocalGuildChannel): """ await self._edit(options, reason=reason) + + class CategoryChannel(discord.abc.GuildChannel, Hashable): """Represents a Discord channel category. @@ -1247,50 +1284,48 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): __slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) + def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) self._update(guild, data) - def __repr__(self): + def __repr__(self) -> str: return f'' - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.nsfw = data.get('nsfw', False) - self.position = data['position'] + def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.nsfw: bool = data.get('nsfw', False) + self.position: int = data['position'] self._fill_overwrites(data) @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: return ChannelType.category.value @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.category - def is_nsfw(self): + def is_nsfw(self) -> bool: """:class:`bool`: Checks if the category is NSFW.""" return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: str = None, reason: Optional[str] = None) -> CategoryChannel: - return await self._clone_impl({ - 'nsfw': self.nsfw - }, name=name, reason=reason) + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel: + return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) @overload async def edit( self, *, - reason: Optional[str] = ..., name: str = ..., position: int = ..., nsfw: bool = ..., overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + reason: Optional[str] = ..., ) -> None: ... @@ -1341,11 +1376,12 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): await super().move(**kwargs) @property - def channels(self): + def channels(self) -> List[GuildChannelType]: """List[:class:`abc.GuildChannel`]: Returns the channels that are under this category. These are sorted by the official Discord UI, which places voice channels below the text channels. """ + def comparator(channel): return (not isinstance(channel, TextChannel), channel.position) @@ -1354,36 +1390,30 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return ret @property - def text_channels(self): + def text_channels(self) -> List[TextChannel]: """List[:class:`TextChannel`]: Returns the text channels that are under this category.""" - ret = [c for c in self.guild.channels - if c.category_id == self.id - and isinstance(c, TextChannel)] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret @property - def voice_channels(self): + def voice_channels(self) -> List[VoiceChannel]: """List[:class:`VoiceChannel`]: Returns the voice channels that are under this category.""" - ret = [c for c in self.guild.channels - if c.category_id == self.id - and isinstance(c, VoiceChannel)] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret @property - def stage_channels(self): + def stage_channels(self) -> List[StageChannel]: """List[:class:`StageChannel`]: Returns the stage channels that are under this category. .. versionadded:: 1.7 """ - ret = [c for c in self.guild.channels - if c.category_id == self.id - and isinstance(c, StageChannel)] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret - async def create_text_channel(self, name, **options): + async def create_text_channel(self, name: str, **options: Any) -> TextChannel: """|coro| A shortcut method to :meth:`Guild.create_text_channel` to create a :class:`TextChannel` in the category. @@ -1395,7 +1425,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): """ return await self.guild.create_text_channel(name, category=self, **options) - async def create_voice_channel(self, name, **options): + async def create_voice_channel(self, name: str, **options: Any) -> VoiceChannel: """|coro| A shortcut method to :meth:`Guild.create_voice_channel` to create a :class:`VoiceChannel` in the category. @@ -1407,7 +1437,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): """ return await self.guild.create_voice_channel(name, category=self, **options) - async def create_stage_channel(self, name, **options): + async def create_stage_channel(self, name: str, **options: Any) -> StageChannel: """|coro| A shortcut method to :meth:`Guild.create_stage_channel` to create a :class:`StageChannel` in the category. @@ -1421,6 +1451,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): """ return await self.guild.create_stage_channel(name, category=self, **options) + class StoreChannel(discord.abc.GuildChannel, Hashable): """Represents a Discord guild store channel. @@ -1462,52 +1493,59 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. """ - __slots__ = ('name', 'id', 'guild', '_state', 'nsfw', - 'category_id', 'position', '_overwrites',) - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) + __slots__ = ( + 'name', + 'id', + 'guild', + '_state', + 'nsfw', + 'category_id', + 'position', + '_overwrites', + ) + + def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) self._update(guild, data) - def __repr__(self): + def __repr__(self) -> str: return f'' - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.position = data['position'] - self.nsfw = data.get('nsfw', False) + def _update(self, guild: Guild, data: StoreChannelPayload) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.position: int = data['position'] + self.nsfw: bool = data.get('nsfw', False) self._fill_overwrites(data) @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: return ChannelType.text.value @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.store @utils.copy_doc(discord.abc.GuildChannel.permissions_for) - def permissions_for(self, member): - base = super().permissions_for(member) + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + base = super().permissions_for(obj) # store channels do not have voice related permissions denied = Permissions.voice() base.value &= ~denied.value return base - def is_nsfw(self): + def is_nsfw(self) -> bool: """:class:`bool`: Checks if the channel is NSFW.""" return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StoreChannel: - return await self._clone_impl({ - 'nsfw': self.nsfw - }, name=name, reason=reason) + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel: + return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) @overload async def edit( @@ -1519,7 +1557,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): sync_permissions: bool = ..., category: Optional[CategoryChannel], reason: Optional[str], - overwrites: Dict[Union[Role, Member], PermissionOverwrite] + overwrites: Dict[Union[Role, Member], PermissionOverwrite], ) -> None: ... @@ -1569,6 +1607,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): """ await self._edit(options, reason=reason) + +DMC = TypeVar('DMC', bound='DMChannel') + + class DMChannel(discord.abc.Messageable, Hashable): """Represents a Discord direct message channel. @@ -1604,43 +1646,43 @@ class DMChannel(discord.abc.Messageable, Hashable): __slots__ = ('id', 'recipient', 'me', '_state') - def __init__(self, *, me, state, data): - self._state = state - self.recipient = state.store_user(data['recipients'][0]) - self.me = me - self.id = int(data['id']) + def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): + self._state: ConnectionState = state + self.recipient: Optional[User] = state.store_user(data['recipients'][0]) + self.me: ClientUser = me + self.id: int = int(data['id']) async def _get_channel(self): return self - def __str__(self): + def __str__(self) -> str: if self.recipient: return f'Direct Message with {self.recipient}' return 'Direct Message with Unknown User' - def __repr__(self): + def __repr__(self) -> str: return f'' @classmethod - def _from_message(cls, state, channel_id): - self = cls.__new__(cls) + def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC: + self: DMC = cls.__new__(cls) self._state = state self.id = channel_id self.recipient = None - self.me = state.user + self.me = state.user # type: ignore return self @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.private @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the direct message channel's creation time in UTC.""" return utils.snowflake_time(self.id) - def permissions_for(self, user=None): + def permissions_for(self, obj: Any = None, /) -> Permissions: """Handles permission resolution for a :class:`User`. This function is there for compatibility with other channel types. @@ -1654,9 +1696,9 @@ class DMChannel(discord.abc.Messageable, Hashable): Parameters ----------- - user: :class:`User` + obj: :class:`User` The user to check permissions for. This parameter is ignored - but kept for compatibility. + but kept for compatibility with other ``permissions_for`` methods. Returns -------- @@ -1670,7 +1712,7 @@ class DMChannel(discord.abc.Messageable, Hashable): base.manage_messages = False return base - def get_partial_message(self, message_id): + def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. This is useful if you want to work with a message and only have its ID without @@ -1690,8 +1732,10 @@ class DMChannel(discord.abc.Messageable, Hashable): """ from .message import PartialMessage + return PartialMessage(channel=self, id=message_id) + class GroupChannel(discord.abc.Messageable, Hashable): """Represents a Discord group channel. @@ -1721,39 +1765,40 @@ class GroupChannel(discord.abc.Messageable, Hashable): The user presenting yourself. id: :class:`int` The group channel ID. - owner: :class:`User` + owner: Optional[:class:`User`] The user that owns the group channel. + owner_id: :class:`int` + The owner ID that owns the group channel. + + .. versionadded:: 2.0 name: Optional[:class:`str`] The group channel's name if provided. """ - __slots__ = ('id', 'recipients', 'owner', '_icon', 'name', 'me', '_state') + __slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state') - def __init__(self, *, me, state, data): - self._state = state - self.id = int(data['id']) - self.me = me + def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self.me: ClientUser = me self._update_group(data) - def _update_group(self, data): - owner_id = utils._get_as_snowflake(data, 'owner_id') - self._icon = data.get('icon') - self.name = data.get('name') - - try: - self.recipients = [self._state.store_user(u) for u in data['recipients']] - except KeyError: - pass + def _update_group(self, data: GroupChannelPayload) -> None: + self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id') + self._icon: Optional[str] = data.get('icon') + self.name: Optional[str] = data.get('name') + self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])] - if owner_id == self.me.id: + self.owner: Optional[BaseUser] + if self.owner_id == self.me.id: self.owner = self.me else: - self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) + self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients) async def _get_channel(self): return self - def __str__(self): + def __str__(self) -> str: if self.name: return self.name @@ -1762,27 +1807,27 @@ class GroupChannel(discord.abc.Messageable, Hashable): return ', '.join(map(lambda x: x.name, self.recipients)) - def __repr__(self): + def __repr__(self) -> str: return f'' @property - def type(self): + def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.group @property - def icon(self): + def icon(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the channel's icon asset if available.""" if self._icon is None: return None return Asset._from_icon(self._state, self.id, self._icon, path='channel') @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" return utils.snowflake_time(self.id) - def permissions_for(self, user): + def permissions_for(self, obj: Snowflake, /) -> Permissions: """Handles permission resolution for a :class:`User`. This function is there for compatibility with other channel types. @@ -1798,7 +1843,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): Parameters ----------- - user: :class:`User` + obj: :class:`~discord.abc.Snowflake` The user to check permissions for. Returns @@ -1813,12 +1858,12 @@ class GroupChannel(discord.abc.Messageable, Hashable): base.manage_messages = False base.mention_everyone = True - if user.id == self.owner.id: + if obj.id == self.owner_id: base.kick_members = True return base - async def leave(self): + async def leave(self) -> None: """|coro| Leave the group. @@ -1833,11 +1878,13 @@ class GroupChannel(discord.abc.Messageable, Hashable): await self._state.http.leave_group(self.id) + def _coerce_channel_type(value: Union[ChannelType, int]) -> ChannelType: if isinstance(value, ChannelType): return value return try_enum(ChannelType, value) + def _guild_channel_factory(channel_type: Union[ChannelType, int]): value = _coerce_channel_type(channel_type) if value is ChannelType.text: @@ -1855,6 +1902,7 @@ def _guild_channel_factory(channel_type: Union[ChannelType, int]): else: return None, value + def _channel_factory(channel_type: Union[ChannelType, int]): cls, value = _guild_channel_factory(channel_type) if value is ChannelType.private: diff --git a/discord/guild.py b/discord/guild.py index 024aa5d87..b43b12df2 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -461,7 +461,7 @@ class Guild(Hashable): for c in channels: factory, ch_type = _guild_channel_factory(c['type']) if factory: - self._add_channel(factory(guild=self, data=c, state=self._state)) + self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore if 'threads' in data: threads = data['threads']