diff --git a/discord/client.py b/discord/client.py index 3d73d001e..e49cc2aeb 100644 --- a/discord/client.py +++ b/discord/client.py @@ -78,7 +78,8 @@ from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factor if TYPE_CHECKING: from .types.guild import Guild as GuildPayload - from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake + from .abc import SnowflakeTime, Snowflake, PrivateChannel + from .guild import GuildChannel from .channel import DMChannel from .message import Message from .member import Member diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 3563064f1..c75a589d0 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -33,7 +33,7 @@ import importlib.util import sys import traceback import types -from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union +from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload import discord @@ -65,6 +65,7 @@ MISSING: Any = discord.utils.MISSING T = TypeVar('T') CFT = TypeVar('CFT', bound='CoroFunc') CXT = TypeVar('CXT', bound='Context') +BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]') def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: @@ -932,7 +933,15 @@ class BotBase(GroupMixin): return ret - async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT: + @overload + async def get_context(self: BT, message: Message) -> Context[BT]: + ... + + @overload + async def get_context(self, message: Message, *, cls: Type[CXT] = ...) -> CXT: + ... + + async def get_context(self, message: Message, *, cls: Type[Context] = Context) -> Any: r"""|coro| Returns the invocation context from the message. diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 5c233dbc6..d6094d373 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -41,6 +41,7 @@ from typing import ( Tuple, Union, runtime_checkable, + overload, ) import discord @@ -48,7 +49,8 @@ from .errors import * if TYPE_CHECKING: from .context import Context - from discord.message import PartialMessageableChannel + from discord.state import Channel + from discord.threads import Thread from .bot import Bot, AutoShardedBot _Bot = Union[Bot, AutoShardedBot] @@ -357,7 +359,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _resolve_channel( ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int] - ) -> Optional[PartialMessageableChannel]: + ) -> Optional[Union[Channel, Thread]]: if channel_id is None: # we were passed just a message id so we can assume the channel is the current context channel return ctx.channel @@ -373,8 +375,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage: guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) channel = self._resolve_channel(ctx, guild_id, channel_id) - if not channel: - raise ChannelNotFound(channel_id) + if not channel or not isinstance(channel, discord.abc.Messageable): + raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here return discord.PartialMessage(channel=channel, id=message_id) @@ -399,14 +401,14 @@ class MessageConverter(IDConverter[discord.Message]): if message: return message channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id) - if not channel: - raise ChannelNotFound(channel_id) + if not channel or not isinstance(channel, discord.abc.Messageable): + raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here try: return await channel.fetch_message(message_id) except discord.NotFound: raise MessageNotFound(argument) except discord.Forbidden: - raise ChannelNotReadable(channel) + raise ChannelNotReadable(channel) # type: ignore - type-checker thinks channel could be a DMChannel at this point class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): @@ -449,7 +451,8 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): else: channel_id = int(match.group(1)) if guild: - result = guild.get_channel(channel_id) + # guild.get_channel returns an explicit union instead of the base class + result = guild.get_channel(channel_id) # type: ignore else: result = _get_from_guilds(bot, 'get_channel', channel_id) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index c420ca4fe..16adcca8d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -99,7 +99,7 @@ __all__ = ( MISSING: Any = discord.utils.MISSING T = TypeVar('T') -CogT = TypeVar('CogT', bound='Cog') +CogT = TypeVar('CogT', bound='Optional[Cog]') CommandT = TypeVar('CommandT', bound='Command') ContextT = TypeVar('ContextT', bound='Context') # CHT = TypeVar('CHT', bound='Check') @@ -307,7 +307,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): Callable[Concatenate[ContextT, P], Coro[T]], ], **kwargs: Any, - ): + ) -> None: if not asyncio.iscoroutinefunction(func): raise TypeError('Callback must be a coroutine.') @@ -372,7 +372,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.require_var_positional: bool = kwargs.get('require_var_positional', False) self.ignore_extra: bool = kwargs.get('ignore_extra', True) self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False) - self.cog: Optional[CogT] = None + self.cog: CogT = None # bandaid for the fact that sometimes parent can be the bot instance parent = kwargs.get('parent') @@ -1321,9 +1321,8 @@ class GroupMixin(Generic[CogT]): @overload def command( - self, + self: GroupMixin[CogT], name: str = ..., - cls: Type[Command[CogT, P, T]] = ..., *args: Any, **kwargs: Any, ) -> Callable[ @@ -1339,21 +1338,29 @@ class GroupMixin(Generic[CogT]): @overload def command( - self, + self: GroupMixin[CogT], name: str = ..., cls: Type[CommandT] = ..., *args: Any, **kwargs: Any, - ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: + ) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], + CommandT, + ]: ... def command( self, name: str = MISSING, - cls: Type[CommandT] = MISSING, + cls: Type[Command] = MISSING, *args: Any, **kwargs: Any, - ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: + ) -> Any: """A shortcut decorator that invokes :func:`.command` and adds it to the internal command list via :meth:`~.GroupMixin.add_command`. @@ -1363,7 +1370,8 @@ class GroupMixin(Generic[CogT]): A decorator that converts the provided method into a Command, adds it to the bot, then returns it. """ - def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: + def decorator(func): + kwargs.setdefault('parent', self) result = command(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) @@ -1373,34 +1381,46 @@ class GroupMixin(Generic[CogT]): @overload def group( - self, + self: GroupMixin[CogT], name: str = ..., - cls: Type[Group[CogT, P, T]] = ..., *args: Any, **kwargs: Any, ) -> Callable[ - [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], Group[CogT, P, T], ]: ... @overload def group( - self, + self: GroupMixin[CogT], name: str = ..., cls: Type[GroupT] = ..., *args: Any, **kwargs: Any, - ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: + ) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], + GroupT, + ]: ... def group( self, name: str = MISSING, - cls: Type[GroupT] = MISSING, + cls: Type[Group] = MISSING, *args: Any, **kwargs: Any, - ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: + ) -> Any: """A shortcut decorator that invokes :func:`.group` and adds it to the internal command list via :meth:`~.GroupMixin.add_command`. @@ -1410,7 +1430,7 @@ class GroupMixin(Generic[CogT]): A decorator that converts the provided method into a Group, adds it to the bot, then returns it. """ - def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: + def decorator(func): kwargs.setdefault('parent', self) result = group(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) @@ -1533,21 +1553,39 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): # Decorators +if TYPE_CHECKING: + # Using a class to emulate a function allows for overloading the inner function in the decorator. + + class _CommandDecorator: + @overload + def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Command[CogT, P, T]: + ... + + @overload + def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Command[None, P, T]: + ... + + def __call__(self, func: Callable[..., Coro[T]], /) -> Any: + ... + + class _GroupDecorator: + @overload + def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Group[CogT, P, T]: + ... + + @overload + def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Group[None, P, T]: + ... + + def __call__(self, func: Callable[..., Coro[T]], /) -> Any: + ... + @overload def command( name: str = ..., - cls: Type[Command[CogT, P, T]] = ..., **attrs: Any, -) -> Callable[ - [ - Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]], - ] - ], - Command[CogT, P, T], -]: +) -> _CommandDecorator: ... @@ -1559,8 +1597,8 @@ def command( ) -> Callable[ [ Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance ] ], CommandT, @@ -1570,17 +1608,9 @@ def command( def command( name: str = MISSING, - cls: Type[CommandT] = MISSING, + cls: Type[Command] = MISSING, **attrs: Any, -) -> Callable[ - [ - Union[ - Callable[Concatenate[ContextT, P], Coro[Any]], - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - ] - ], - Union[Command[CogT, P, T], CommandT], -]: +) -> Any: """A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. @@ -1611,14 +1641,9 @@ def command( If the function is not a coroutine or is already a command. """ if cls is MISSING: - cls = Command # type: ignore + cls = Command - def decorator( - func: Union[ - Callable[Concatenate[ContextT, P], Coro[Any]], - Callable[Concatenate[CogT, ContextT, P], Coro[Any]], - ] - ) -> CommandT: + def decorator(func): if isinstance(func, Command): raise TypeError('Callback is already a command.') return cls(func, name=name, **attrs) @@ -1629,17 +1654,8 @@ def command( @overload def group( name: str = ..., - cls: Type[Group[CogT, P, T]] = ..., **attrs: Any, -) -> Callable[ - [ - Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]], - ] - ], - Group[CogT, P, T], -]: +) -> _GroupDecorator: ... @@ -1651,7 +1667,7 @@ def group( ) -> Callable[ [ Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[Any]], + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance Callable[Concatenate[ContextT, P], Coro[Any]], ] ], @@ -1662,17 +1678,9 @@ def group( def group( name: str = MISSING, - cls: Type[GroupT] = MISSING, + cls: Type[Group] = MISSING, **attrs: Any, -) -> Callable[ - [ - Union[ - Callable[Concatenate[ContextT, P], Coro[Any]], - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - ] - ], - Union[Group[CogT, P, T], GroupT], -]: +) -> Any: """A decorator that transforms a function into a :class:`.Group`. This is similar to the :func:`.command` decorator but the ``cls`` @@ -1682,8 +1690,9 @@ def group( The ``cls`` parameter can now be passed. """ if cls is MISSING: - cls = Group # type: ignore - return command(name=name, cls=cls, **attrs) # type: ignore + cls = Group + + return command(name=name, cls=cls, **attrs) def check(predicate: Check) -> Callable[[T], T]: diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index 6ba31e1a7..b86298822 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError # map from opening quotes to closing quotes @@ -177,7 +176,7 @@ class StringView: next_char = self.get() valid_eof = not next_char or next_char.isspace() if not valid_eof: - raise InvalidEndOfQuotedStringError(next_char) + raise InvalidEndOfQuotedStringError(next_char) # type: ignore - this will always be a string # we're quoted so it's okay return ''.join(result)