diff --git a/discord/abc.py b/discord/abc.py index 88e4bd969..a785e35dd 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,7 +26,7 @@ from __future__ import annotations import copy import asyncio -from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, TypeVar, Union, overload, runtime_checkable +from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, Type, TypeVar, Union, overload, runtime_checkable from .iterators import HistoryIterator from .context_managers import Typing @@ -49,6 +49,8 @@ __all__ = ( 'Connectable', ) +T = TypeVar('T', bound=VoiceProtocol) + if TYPE_CHECKING: from datetime import datetime @@ -58,7 +60,8 @@ if TYPE_CHECKING: from .guild import Guild from .member import Member from .channel import CategoryChannel - + from .embeds import Embed + from .message import Message, MessageReference MISSING = utils.MISSING @@ -95,6 +98,7 @@ class Snowflake(Protocol): """:class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC.""" raise NotImplementedError +SnowflakeTime = Union[Snowflake, datetime] @runtime_checkable class User(Snowflake, Protocol): @@ -653,14 +657,34 @@ class GuildChannel: """ await self._state.http.delete_channel(self.id, reason=reason) + @overload async def set_permissions( self, target: Union[Member, Role], *, - overwrite: Optional[PermissionOverwrite] = _undefined, - reason: Optional[str] = None, + overwrite: Optional[Union[PermissionOverwrite, _Undefined]] = ..., + reason: Optional[str] = ..., + ) -> None: + ... + + @overload + async def set_permissions( + self, + target: Union[Member, Role], + *, + reason: Optional[str] = ..., **permissions: bool, ) -> None: + ... + + async def set_permissions( + self, + target, + *, + overwrite=_undefined, + reason=None, + **permissions + ): r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -815,7 +839,7 @@ class GuildChannel: offset: int = MISSING, category: Optional[Snowflake] = MISSING, sync_permissions: bool = MISSING, - reason: str = MISSING, + reason: Optional[str] = MISSING, ) -> None: ... @@ -1091,6 +1115,38 @@ class Messageable(Protocol): async def _get_channel(self): raise NotImplementedError + @overload + async def send( + self, + content: Optional[str] =..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + delete_after: int = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference] = ..., + mention_author: bool = ..., + ) -> Message: + ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: List[File] = ..., + delete_after: int = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference] = ..., + mention_author: bool = ..., + ) -> Message: + ... + async def send(self, content=None, *, tts=False, embed=None, file=None, files=None, delete_after=None, nonce=None, allowed_mentions=None, reference=None, @@ -1402,7 +1458,7 @@ class Connectable(Protocol): def _get_voice_state_pair(self): raise NotImplementedError - async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient): + async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T: """|coro| Connects to voice and creates a :class:`VoiceClient` to establish diff --git a/discord/channel.py b/discord/channel.py index 74f274df8..c9eb9c345 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -22,11 +22,14 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import time import asyncio +from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union, overload import discord.abc -from .permissions import Permissions +from .permissions import PermissionOverwrite, Permissions from .enums import ChannelType, try_enum, VoiceRegion, VideoQualityMode from .mixins import Hashable from . import utils @@ -44,6 +47,14 @@ __all__ = ( '_channel_factory', ) +if TYPE_CHECKING: + from .role import Role + from .member import Member + from .abc import Snowflake + from .message import Message + from .webhook import Webhook + from .abc import SnowflakeTime + async def _single_delete_strategy(messages): for m in messages: await m.delete() @@ -190,6 +201,27 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ return self._state._get_message(self.last_message_id) if self.last_message_id else None + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + topic: str = ..., + position: int = ..., + nsfw: bool = ..., + sync_permissions: bool = ..., + category: Optional[CategoryChannel] = ..., + slowmode_delay: int = ..., + type: ChannelType = ..., + overwrites: Dict[Union[Role, Member, Snowflake], PermissionOverwrite] = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **options): """|coro| @@ -246,7 +278,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): await self._edit(options, reason=reason) @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name=None, reason=None): + async def clone(self, *, name: str = None, reason: str = None) -> TextChannel: return await self._clone_impl({ 'topic': self.topic, 'nsfw': self.nsfw, @@ -302,7 +334,17 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): message_ids = [m.id for m in messages] await self._state.http.delete_messages(self.id, message_ids) - async def purge(self, *, limit=100, check=None, before=None, after=None, around=None, oldest_first=False, bulk=True): + async def purge( + self, + *, + limit: int = 100, + check: Callable[[Message], bool] = None, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = False, + bulk: bool = True, + ) -> List[Message]: """|coro| Purges a list of messages that meet the criteria given by the predicate @@ -428,7 +470,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook(self, *, name, avatar=None, reason=None): + async def create_webhook(self, *, name: str, avatar: bytes = None, reason: str = None) -> Webhook: """|coro| Creates a webhook for this channel. @@ -468,7 +510,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) - async def follow(self, *, destination, reason=None): + async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook: """ Follows a channel using a webhook. @@ -680,12 +722,33 @@ class VoiceChannel(VocalGuildChannel): return ChannelType.voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name=None, reason=None): + async def clone(self, *, name: str = None, reason: str = None) -> VoiceChannel: return await self._clone_impl({ 'bitrate': self.bitrate, 'user_limit': self.user_limit }, name=name, reason=reason) + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + bitrate: int = ..., + user_limit: int = ..., + position: int = ..., + sync_permissions: int = ..., + category: Optional[CategoryChannel] = ..., + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + rtc_region: Optional[VoiceRegion] = ..., + video_quality_mode: VideoQualityMode = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **options): """|coro| @@ -822,11 +885,31 @@ class StageChannel(VocalGuildChannel): return ChannelType.stage_voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name=None, reason=None): + async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StageChannel: return await self._clone_impl({ 'topic': self.topic, }, name=name, reason=reason) + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + topic: Optional[str] = ..., + position: int = ..., + sync_permissions: int = ..., + category: Optional[CategoryChannel] = ..., + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + rtc_region: Optional[VoiceRegion] = ..., + video_quality_mode: VideoQualityMode = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **options): """|coro| @@ -839,7 +922,7 @@ class StageChannel(VocalGuildChannel): ---------- name: :class:`str` The new channel's name. - topic: :class:`str` + topic: Optional[:class:`str`] The new channel's topic. position: :class:`int` The new channel's position. @@ -873,7 +956,6 @@ class StageChannel(VocalGuildChannel): """ await self._edit(options, reason=reason) - class CategoryChannel(discord.abc.GuildChannel, Hashable): """Represents a Discord channel category. @@ -948,11 +1030,27 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return self.nsfw or self.guild.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name=None, reason=None): + async def clone(self, *, name: str = None, reason: Optional[str] = None) -> CategoryChannel: return await self._clone_impl({ 'nsfw': self.nsfw }, name=name, reason=reason) + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + position: int = ..., + nsfw: bool = ..., + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **options): """|coro| @@ -1159,11 +1257,29 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): return self.nsfw or self.guild.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name=None, reason=None): + async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StoreChannel: return await self._clone_impl({ 'nsfw': self.nsfw }, name=name, reason=reason) + @overload + async def edit( + self, + *, + name: str = ..., + position: int = ..., + nsfw: bool = ..., + sync_permissions: bool = ..., + category: Optional[CategoryChannel], + reason: Optional[str], + overwrites: Dict[Union[Role, Member], PermissionOverwrite] + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **options): """|coro| diff --git a/discord/client.py b/discord/client.py index f30ed7f54..494a838a4 100644 --- a/discord/client.py +++ b/discord/client.py @@ -22,12 +22,14 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio import logging import signal import sys import traceback -from typing import Any, Optional, Union +from typing import Any, List, Optional, TYPE_CHECKING, Union import aiohttp @@ -58,6 +60,9 @@ __all__ = ( 'Client', ) +if TYPE_CHECKING: + from .abc import SnowflakeTime + log = logging.getLogger(__name__) def _cancel_tasks(loop): @@ -968,7 +973,7 @@ class Client: # Guild stuff - def fetch_guilds(self, *, limit=100, before=None, after=None): + def fetch_guilds(self, *, limit: int = 100, before: SnowflakeTime = None, after: SnowflakeTime = None) -> List[Guild]: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. .. note:: diff --git a/discord/colour.py b/discord/colour.py index 5baf379fa..764cf6e2e 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -27,7 +27,7 @@ import random from typing import ( Any, - Optional, + Optional, Tuple, Type, TypeVar, @@ -65,7 +65,7 @@ class Colour: .. describe:: str(x) Returns the hex format for the colour. - + .. describe:: int(x) Returns the raw colour value. @@ -95,7 +95,7 @@ class Colour: def __str__(self) -> str: return f'#{self.value:0>6x}' - + def __int__(self) -> int: return self.value diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 0f3c546ee..c5367c24f 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -143,7 +143,7 @@ class Context(discord.abc.Messageable): ret = await command.callback(*arguments, **kwargs) return ret - async def reinvoke(self, *, call_hooks=False, restart=True): + async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True): """|coro| Calls the command again. diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 14dd5659c..d1b9a901d 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -782,7 +782,7 @@ class clean_content(Converter[str]): .. versionadded:: 1.7 """ - def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False): + def __init__(self, *, fix_channel_mentions: bool = False, use_nicknames: bool = True, escape_markdown: bool = False, remove_markdown: bool = False) -> None: self.fix_channel_mentions = fix_channel_mentions self.use_nicknames = use_nicknames self.escape_markdown = escape_markdown diff --git a/discord/guild.py b/discord/guild.py index 7034da76c..b8ba0ddd5 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. import copy from collections import namedtuple -from typing import List, Optional, TYPE_CHECKING, overload +from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload from . import utils, abc from .role import Role @@ -35,7 +35,7 @@ from .permissions import PermissionOverwrite from .colour import Colour from .errors import InvalidArgument, ClientException from .channel import * -from .enums import VoiceRegion, ChannelType, try_enum, VerificationLevel, ContentFilter, NotificationLevel +from .enums import AuditLogAction, VideoQualityMode, VoiceRegion, ChannelType, try_enum, VerificationLevel, ContentFilter, NotificationLevel from .mixins import Hashable from .user import User from .invite import Invite @@ -53,6 +53,11 @@ if TYPE_CHECKING: from .types.guild import ( Ban as BanPayload ) + from .permissions import Permissions + from .channel import VoiceChannel, StageChannel + from .template import Template + + VocalGuildChannel = Union[VoiceChannel, StageChannel] BanEntry = namedtuple('BanEntry', 'reason user') _GuildLimit = namedtuple('_GuildLimit', 'emoji bitrate filesize') @@ -765,6 +770,28 @@ class Guild(Hashable): return self._state.http.create_channel(self.id, channel_type.value, name=name, parent_id=parent_id, permission_overwrites=perms, **options) + @overload + async def create_text_channel( + self, + name: str, + *, + reason: Optional[str] = ..., + category: Optional[CategoryChannel], + position: int = ..., + topic: Optional[str] = ..., + slowmode_delay: int = ..., + nsfw: bool = ..., + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + ) -> TextChannel: + ... + + @overload + async def create_text_channel( + self, + name: str + ) -> TextChannel: + ... + async def create_text_channel(self, name, *, overwrites=None, category=None, reason=None, **options): """|coro| @@ -850,6 +877,29 @@ class Guild(Hashable): self._channels[channel.id] = channel return channel + @overload + async def create_voice_channel( + self, + name: str, + *, + reason: Optional[str] = ..., + category: Optional[CategoryChannel], + position: int = ..., + bitrate: int = ..., + user_limit: int = ..., + rtc_region: Optional[VoiceRegion] = ..., + voice_quality_mode: VideoQualityMode = ..., + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + ) -> VoiceChannel: + ... + + @overload + async def create_voice_channel( + self, + name: str + ) -> VoiceChannel: + ... + async def create_voice_channel(self, name, *, overwrites=None, category=None, reason=None, **options): """|coro| @@ -893,7 +943,16 @@ class Guild(Hashable): self._channels[channel.id] = channel return channel - async def create_stage_channel(self, name, *, topic=None, category=None, overwrites=None, reason=None, position=None): + async def create_stage_channel( + self, + name: str, + *, + reason: Optional[str] = ..., + category: Optional[CategoryChannel], + topic: str, + position: int = ..., + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., + ) -> StageChannel: """|coro| This is similar to :meth:`create_text_channel` except makes a :class:`StageChannel` instead. @@ -925,7 +984,14 @@ class Guild(Hashable): self._channels[channel.id] = channel return channel - async def create_category(self, name, *, overwrites=None, reason=None, position=None): + async def create_category( + self, + name: str, + *, + overwrites: Dict[Union[Role, Member], PermissionOverwrite] = None, + reason: Optional[str] = None, + position: int = None + ) -> CategoryChannel: """|coro| Same as :meth:`create_text_channel` except makes a :class:`CategoryChannel` instead. @@ -1286,7 +1352,7 @@ class Guild(Hashable): return [convert(d) for d in data] - def fetch_members(self, *, limit=1000, after=None): + def fetch_members(self, *, limit: int = 1000, after: Optional[abc.SnowflakeTime] = None) -> List[Member]: """Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this, :meth:`Intents.members` must be enabled. @@ -1472,7 +1538,14 @@ class Guild(Hashable): reason=e['reason']) for e in data] - async def prune_members(self, *, days, compute_prune_count=True, roles=None, reason=None): + async def prune_members( + self, + *, + days: int, + compute_prune_count: bool = True, + roles: Optional[List[abc.Snowflake]] = None, + reason: Optional[str] = None + ) -> Optional[int]: r"""|coro| Prunes the guild from its inactive members. @@ -1576,7 +1649,7 @@ class Guild(Hashable): data = await self._state.http.guild_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def estimate_pruned_members(self, *, days, roles=None): + async def estimate_pruned_members(self, *, days: int, roles: Optional[List[abc.Snowflake]] = None): """|coro| Similar to :meth:`prune_members` except instead of actually @@ -1648,7 +1721,7 @@ class Guild(Hashable): return result - async def create_template(self, *, name, description=None): + async def create_template(self, *, name: str, description: Optional[str] = None) -> Template: """|coro| Creates a template for the guild. @@ -1678,7 +1751,7 @@ class Guild(Hashable): return Template(state=self._state, data=data) - async def create_integration(self, *, type, id): + async def create_integration(self, *, type: str, id: int) -> None: """|coro| Attaches an integration to the guild. @@ -1704,7 +1777,7 @@ class Guild(Hashable): """ await self._state.http.create_integration(self.id, type, id) - async def integrations(self): + async def integrations(self) -> List[Integration]: """|coro| Returns a list of all integrations attached to the guild. @@ -1781,7 +1854,14 @@ class Guild(Hashable): data = await self._state.http.get_custom_emoji(self.id, emoji_id) return Emoji(guild=self, state=self._state, data=data) - async def create_custom_emoji(self, *, name, image, roles=None, reason=None): + async def create_custom_emoji( + self, + *, + name: str, + image: bytes, + roles: Optional[List[Role]] = None, + reason: Optional[str] = None, + ) -> Emoji: r"""|coro| Creates a custom :class:`Emoji` for the guild. @@ -1847,6 +1927,32 @@ class Guild(Hashable): data = await self._state.http.get_roles(self.id) return [Role(guild=self, state=self._state, data=d) for d in data] + @overload + async def create_role( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + permissions: Permissions = ..., + colour: Union[Colour, int] = ..., + hoist: bool = ..., + mentionable: str = ..., + ) -> Role: + ... + + @overload + async def create_role( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + permissions: Permissions = ..., + color: Union[Colour, int] = ..., + hoist: bool = ..., + mentionable: str = ..., + ) -> Role: + ... + async def create_role(self, *, reason=None, **fields): """|coro| @@ -1920,7 +2026,7 @@ class Guild(Hashable): # TODO: add to cache return role - async def edit_role_positions(self, positions, *, reason=None): + async def edit_role_positions(self, positions: Dict[abc.Snowflake, int], *, reason: Optional[str] = None) -> List[Role]: """|coro| Bulk edits a list of :class:`Role` in the guild. @@ -1986,7 +2092,7 @@ class Guild(Hashable): return roles - async def kick(self, user, *, reason=None): + async def kick(self, user: abc.Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Kicks a user from the guild. @@ -2012,7 +2118,13 @@ class Guild(Hashable): """ await self._state.http.kick(user.id, self.id, reason=reason) - async def ban(self, user, *, reason=None, delete_message_days=1): + async def ban( + self, + user: abc.Snowflake, + *, + reason: Optional[str] = None, + delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1 + ) -> None: """|coro| Bans a user from the guild. @@ -2041,7 +2153,7 @@ class Guild(Hashable): """ await self._state.http.ban(user.id, self.id, delete_message_days, reason=reason) - async def unban(self, user, *, reason=None): + async def unban(self, user: abc.Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Unbans a user from the guild. @@ -2105,7 +2217,16 @@ class Guild(Hashable): payload['max_age'] = 0 return Invite(state=self._state, data=payload) - def audit_logs(self, *, limit=100, before=None, after=None, oldest_first=None, user=None, action=None): + def audit_logs( + self, + *, + limit: int = 100, + before: Optional[abc.SnowflakeTime] = None, + after: Optional[abc.SnowflakeTime] = None, + oldest_first: Optional[bool] = None, + user: abc.Snowflake = None, + action: AuditLogAction = None + ) -> AuditLogIterator: """Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs. You must have the :attr:`~Permissions.view_audit_log` permission to use this. @@ -2194,7 +2315,7 @@ class Guild(Hashable): return Widget(state=self._state, data=data) - async def chunk(self, *, cache=True): + async def chunk(self, *, cache: bool = True) -> None: """|coro| Requests all members that belong to this guild. In order to use this, @@ -2221,7 +2342,15 @@ class Guild(Hashable): if not self._state.is_guild_evicted(self): return await self._state.chunk_guild(self, cache=cache) - async def query_members(self, query=None, *, limit=5, user_ids=None, presences=False, cache=True): + async def query_members( + self, + query: Optional[str] = None, + *, + limit: int = 5, + user_ids: Optional[List[int]] = None, + presences: bool = False, + cache: bool = True + ) -> List[Member]: """|coro| Request members that belong to this guild whose username starts with @@ -2287,7 +2416,7 @@ class Guild(Hashable): limit = min(100, limit or 5) return await self._state.query_members(self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache) - async def change_voice_state(self, *, channel, self_mute=False, self_deaf=False): + async def change_voice_state(self, *, channel: Optional[VocalGuildChannel], self_mute: bool = False, self_deaf: bool = False): """|coro| Changes client's voice state in the guild. diff --git a/discord/member.py b/discord/member.py index 4eeebcba8..008de3650 100644 --- a/discord/member.py +++ b/discord/member.py @@ -27,11 +27,11 @@ import inspect import itertools import sys from operator import attrgetter +from typing import List, Literal, Optional, TYPE_CHECKING, Union, overload import discord.abc from . import utils -from .errors import ClientException from .user import BaseUser, User from .activity import create_activity from .permissions import Permissions @@ -44,6 +44,12 @@ __all__ = ( 'Member', ) +if TYPE_CHECKING: + from .channel import VoiceChannel, StageChannel + from .abc import Snowflake + + VocalGuildChannel = Union[VoiceChannel, StageChannel] + class VoiceState: """Represents a Discord user's voice state. @@ -517,6 +523,19 @@ class Member(discord.abc.Messageable, _BaseUser): """Optional[:class:`VoiceState`]: Returns the member's current voice state.""" return self.guild._voice_state_for(self._user.id) + @overload + async def ban( + self, + *, + reason: Optional[str] = ..., + delete_message_days: Literal[1, 2, 3, 4, 5, 6, 7] = ..., + ) -> None: + ... + + @overload + async def ban(self) -> None: + ... + async def ban(self, **kwargs): """|coro| @@ -524,20 +543,38 @@ class Member(discord.abc.Messageable, _BaseUser): """ await self.guild.ban(self, **kwargs) - async def unban(self, *, reason=None): + async def unban(self, *, reason: Optional[str] = None) -> None: """|coro| Unbans this member. Equivalent to :meth:`Guild.unban`. """ await self.guild.unban(self, reason=reason) - async def kick(self, *, reason=None): + async def kick(self, *, reason: Optional[str] = None) -> None: """|coro| Kicks this member. Equivalent to :meth:`Guild.kick`. """ await self.guild.kick(self, reason=reason) + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + nick: Optional[str] = None, + mute: bool = ..., + deafen: bool = ..., + suppress: bool = ..., + roles: Optional[List[discord.abc.Snowflake]] = ..., + voice_channel: Optional[VocalGuildChannel] = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **fields): """|coro| @@ -685,7 +722,7 @@ class Member(discord.abc.Messageable, _BaseUser): else: await self._state.http.edit_my_voice_state(self.guild.id, payload) - async def move_to(self, channel, *, reason=None): + async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None: """|coro| Moves a member to a new voice channel (they must be connected first). @@ -708,7 +745,7 @@ class Member(discord.abc.Messageable, _BaseUser): """ await self.edit(voice_channel=channel, reason=reason) - async def add_roles(self, *roles, reason=None, atomic=True): + async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True): r"""|coro| Gives the member a number of :class:`Role`\s. @@ -747,7 +784,7 @@ class Member(discord.abc.Messageable, _BaseUser): for role in roles: await req(guild_id, user_id, role.id, reason=reason) - async def remove_roles(self, *roles, reason=None, atomic=True): + async def remove_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: r"""|coro| Removes :class:`Role`\s from this member. diff --git a/discord/message.py b/discord/message.py index 68c78001e..9016960a0 100644 --- a/discord/message.py +++ b/discord/message.py @@ -29,7 +29,7 @@ import datetime import re import io from os import PathLike -from typing import TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar +from typing import TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload from . import utils from .reaction import Reaction @@ -63,6 +63,7 @@ if TYPE_CHECKING: from .abc import GuildChannel from .state import ConnectionState from .channel import TextChannel, GroupChannel, DMChannel + from .mentions import AllowedMentions EmojiInputType = Union[Emoji, PartialEmoji, str] @@ -398,7 +399,7 @@ class MessageReference: return self @classmethod - def from_message(cls, message: Message, *, fail_if_not_exists: bool = True): + def from_message(cls, message: Message, *, fail_if_not_exists: bool = True) -> MessageReference: """Creates a :class:`MessageReference` from an existing :class:`~discord.Message`. .. versionadded:: 1.6 @@ -1077,7 +1078,24 @@ class Message(Hashable): else: await self._state.http.delete_message(self.channel.id, self.id) - async def edit(self, **fields: Any) -> None: + @overload + async def edit( + self, + *, + content: Optional[str] = ..., + embed: Optional[Embed] = ..., + attachments: List[Attachment] = ..., + suppress: bool = ..., + delete_after: Optional[float] = ..., + allowed_mentions: Optional[AllowedMentions] = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + + async def edit(self, **fields) -> None: """|coro| Edits the message. diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index 83fc91921..c9ddd6504 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations + from typing import Any, Dict, Optional, TYPE_CHECKING, Type, TypeVar from .asset import Asset, AssetMixin @@ -36,7 +37,7 @@ __all__ = ( if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime - + from .types.message import PartialEmoji as PartialEmojiPayload class _EmojiTag: __slots__ = () diff --git a/discord/role.py b/discord/role.py index 491fd075b..22778f8e9 100644 --- a/discord/role.py +++ b/discord/role.py @@ -22,6 +22,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Optional, Union, overload + from .permissions import Permissions from .errors import InvalidArgument from .colour import Colour @@ -305,6 +307,24 @@ class Role(Hashable): payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] await http.move_role_position(self.guild.id, payload, reason=reason) + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + permissions: Permissions = ..., + colour: Union[Colour, int] = ..., + hoist: bool = ..., + mentionable: bool = ..., + position: int = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + async def edit(self, *, reason=None, **fields): """|coro| @@ -371,7 +391,7 @@ class Role(Hashable): data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) self._update(data) - async def delete(self, *, reason=None): + async def delete(self, *, reason: Optional[str] = None): """|coro| Deletes the role. diff --git a/discord/user.py b/discord/user.py index 95bcea25a..09e8e43af 100644 --- a/discord/user.py +++ b/discord/user.py @@ -22,12 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import TYPE_CHECKING - +from typing import Optional, TYPE_CHECKING import discord.abc from .flags import PublicUserFlags from .utils import snowflake_time, _bytes_to_base64_data -from .enums import DefaultAvatar, try_enum +from .enums import DefaultAvatar from .colour import Colour from .asset import Asset @@ -248,7 +247,7 @@ class ClientUser(BaseUser): self._flags = data.get('flags', 0) self.mfa_enabled = data.get('mfa_enabled', False) - async def edit(self, *, username=None, avatar=None): + async def edit(self, *, username: str = None, avatar: Optional[bytes] = None) -> None: """|coro| Edits the current profile of the client. diff --git a/discord/voice_client.py b/discord/voice_client.py index 246ed7b6a..2ae2a8b12 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -42,6 +42,7 @@ import socket import logging import struct import threading +from typing import Any, Callable from . import opus, utils from .backoff import ExponentialBackoff @@ -121,7 +122,7 @@ class VoiceProtocol: """ raise NotImplementedError - async def connect(self, *, timeout, reconnect): + async def connect(self, *, timeout: float, reconnect: bool): """|coro| An abstract method called when the client initiates the connection request. @@ -144,7 +145,7 @@ class VoiceProtocol: """ raise NotImplementedError - async def disconnect(self, *, force): + async def disconnect(self, *, force: bool): """|coro| An abstract method called when the client terminates the connection. @@ -328,7 +329,7 @@ class VoiceClient(VoiceProtocol): self._connected.set() return ws - async def connect(self, *, reconnect, timeout): + async def connect(self, *, reconnect: bool, timeout: bool): log.info('Connecting to voice...') self.timeout = timeout @@ -451,7 +452,7 @@ class VoiceClient(VoiceProtocol): log.warning('Could not connect to voice... Retrying...') continue - async def disconnect(self, *, force=False): + async def disconnect(self, *, force: bool = False): """|coro| Disconnects this voice client from voice. @@ -525,7 +526,7 @@ class VoiceClient(VoiceProtocol): return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] - def play(self, source, *, after=None): + def play(self, source: AudioSource, *, after: Callable[[Exception], Any]=None): """Plays an :class:`AudioSource`. The finalizer, ``after`` is called after the source has been exhausted diff --git a/discord/widget.py b/discord/widget.py index c10b55cd7..36b6e3dd5 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -175,7 +175,7 @@ class WidgetMember(BaseUser): else: activity = create_activity(game) - self.activity: Optional[Union[BaseActivity, Spotify]] = activity + self.activity: Optional[Union[BaseActivity, Spotify]] = activity self.connected_channel: Optional[WidgetChannel] = connected_channel @@ -277,10 +277,10 @@ class Widget: """Optional[:class:`str`]: The invite URL for the guild, if available.""" return self._invite - async def fetch_invite(self, *, with_counts: bool = True) -> Optional[Invite]: + async def fetch_invite(self, *, with_counts: bool = True) -> Invite: """|coro| - Retrieves an :class:`Invite` from a invite URL or ID. + Retrieves an :class:`Invite` from the widget's invite URL. This is the same as :meth:`Client.fetch_invite`; the invite code is abstracted away. @@ -293,10 +293,9 @@ class Widget: Returns -------- - Optional[:class:`Invite`] - The invite from the URL/ID. + :class:`Invite` + The invite from the widget's invite URL. """ - if self._invite: - invite_id = resolve_invite(self._invite) - data = await self._state.http.get_invite(invite_id, with_counts=with_counts) - return Invite.from_incomplete(state=self._state, data=data) + invite_id = resolve_invite(self._invite) + data = await self._state.http.get_invite(invite_id, with_counts=with_counts) + return Invite.from_incomplete(state=self._state, data=data)