diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 8c3c53a28..9b1559870 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + +from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union + + +if TYPE_CHECKING: + from .context import Context + from .cog import Cog + from .errors import CommandError + +T = TypeVar('T') + +Coro = Coroutine[Any, Any, T] +MaybeCoro = Union[T, Coro[T]] +CoroFunc = Callable[..., Coro[Any]] + +Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]] +Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] +Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]] + + # This is merely a tag type to avoid circular import issues. # Yes, this is a terrible solution but ultimately it is the only solution. class _BaseCommand: diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 7c49bf961..ba108153b 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + + import asyncio import collections +import collections.abc import inspect import importlib.util import sys import traceback import types +from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union import discord @@ -39,6 +44,15 @@ from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog +if TYPE_CHECKING: + import importlib.machinery + + from discord.message import Message + from ._types import ( + Check, + CoroFunc, + ) + __all__ = ( 'when_mentioned', 'when_mentioned_or', @@ -46,14 +60,21 @@ __all__ = ( 'AutoShardedBot', ) -def when_mentioned(bot, msg): +MISSING: Any = discord.utils.MISSING + +T = TypeVar('T') +CFT = TypeVar('CFT', bound='CoroFunc') +CXT = TypeVar('CXT', bound='Context') + +def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. """ - return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] + # bot.user will never be None when this is called + return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore -def when_mentioned_or(*prefixes): +def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: """A callable that implements when mentioned or other prefixes provided. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. @@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes): return inner -def _is_submodule(parent, child): +def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") class _DefaultRepr: @@ -102,10 +123,10 @@ class BotBase(GroupMixin): def __init__(self, command_prefix, help_command=_default, description=None, **options): super().__init__(**options) self.command_prefix = command_prefix - self.extra_events = {} - self.__cogs = {} - self.__extensions = {} - self._checks = [] + self.extra_events: Dict[str, List[CoroFunc]] = {} + self.__cogs: Dict[str, Cog] = {} + self.__extensions: Dict[str, types.ModuleType] = {} + self._checks: List[Check] = [] self._check_once = [] self._before_invoke = None self._after_invoke = None @@ -128,13 +149,14 @@ class BotBase(GroupMixin): # internal helpers - def dispatch(self, event_name, *args, **kwargs): - super().dispatch(event_name, *args, **kwargs) + def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: + # super() will resolve to Client + super().dispatch(event_name, *args, **kwargs) # type: ignore ev = 'on_' + event_name for event in self.extra_events.get(ev, []): - self._schedule_event(event, ev, *args, **kwargs) + self._schedule_event(event, ev, *args, **kwargs) # type: ignore - async def close(self): + async def close(self) -> None: for extension in tuple(self.__extensions): try: self.unload_extension(extension) @@ -147,9 +169,9 @@ class BotBase(GroupMixin): except Exception: pass - await super().close() + await super().close() # type: ignore - async def on_command_error(self, context, exception): + async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: """|coro| The default command error handler provided by the bot. @@ -175,7 +197,7 @@ class BotBase(GroupMixin): # global check registration - def check(self, func): + def check(self, func: T) -> T: r"""A decorator that adds a global check to the bot. A global check is similar to a :func:`.check` that is applied @@ -200,10 +222,11 @@ class BotBase(GroupMixin): return ctx.command.qualified_name in allowed_commands """ - self.add_check(func) + # T was used instead of Check to ensure the type matches on return + self.add_check(func) # type: ignore return func - def add_check(self, func, *, call_once=False): + def add_check(self, func: Check, *, call_once: bool = False) -> None: """Adds a global check to the bot. This is the non-decorator interface to :meth:`.check` @@ -223,7 +246,7 @@ class BotBase(GroupMixin): else: self._checks.append(func) - def remove_check(self, func, *, call_once=False): + def remove_check(self, func: Check, *, call_once: bool = False) -> None: """Removes a global check from the bot. This function is idempotent and will not raise an exception @@ -244,7 +267,7 @@ class BotBase(GroupMixin): except ValueError: pass - def check_once(self, func): + def check_once(self, func: CFT) -> CFT: r"""A decorator that adds a "call once" global check to the bot. Unlike regular global checks, this one is called only once @@ -282,15 +305,16 @@ class BotBase(GroupMixin): self.add_check(func, call_once=True) return func - async def can_run(self, ctx, *, call_once=False): + async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks if len(data) == 0: return True - return await discord.utils.async_all(f(ctx) for f in data) + # type-checker doesn't distinguish between functions and methods + return await discord.utils.async_all(f(ctx) for f in data) # type: ignore - async def is_owner(self, user): + async def is_owner(self, user: discord.User) -> bool: """|coro| Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of @@ -319,7 +343,8 @@ class BotBase(GroupMixin): elif self.owner_ids: return user.id in self.owner_ids else: - app = await self.application_info() + + app = await self.application_info() # type: ignore if app.team: self.owner_ids = ids = {m.id for m in app.team.members} return user.id in ids @@ -327,7 +352,7 @@ class BotBase(GroupMixin): self.owner_id = owner_id = app.owner.id return user.id == owner_id - def before_invoke(self, coro): + def before_invoke(self, coro: CFT) -> CFT: """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is @@ -359,7 +384,7 @@ class BotBase(GroupMixin): self._before_invoke = coro return coro - def after_invoke(self, coro): + def after_invoke(self, coro: CFT) -> CFT: r"""A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is @@ -394,14 +419,14 @@ class BotBase(GroupMixin): # listener registration - def add_listener(self, func, name=None): + def add_listener(self, func: CoroFunc, name: str = MISSING) -> None: """The non decorator alternative to :meth:`.listen`. Parameters ----------- func: :ref:`coroutine ` The function to call. - name: Optional[:class:`str`] + name: :class:`str` The name of the event to listen for. Defaults to ``func.__name__``. Example @@ -416,7 +441,7 @@ class BotBase(GroupMixin): bot.add_listener(my_message, 'on_message') """ - name = func.__name__ if name is None else name + name = func.__name__ if name is MISSING else name if not asyncio.iscoroutinefunction(func): raise TypeError('Listeners must be coroutines') @@ -426,7 +451,7 @@ class BotBase(GroupMixin): else: self.extra_events[name] = [func] - def remove_listener(self, func, name=None): + def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None: """Removes a listener from the pool of listeners. Parameters @@ -438,7 +463,7 @@ class BotBase(GroupMixin): ``func.__name__``. """ - name = func.__name__ if name is None else name + name = func.__name__ if name is MISSING else name if name in self.extra_events: try: @@ -446,7 +471,7 @@ class BotBase(GroupMixin): except ValueError: pass - def listen(self, name=None): + def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]: """A decorator that registers another function as an external event listener. Basically this allows you to listen to multiple events from different places e.g. such as :func:`.on_ready` @@ -476,7 +501,7 @@ class BotBase(GroupMixin): The function being listened to is not a coroutine. """ - def decorator(func): + def decorator(func: CFT) -> CFT: self.add_listener(func, name) return func @@ -528,7 +553,7 @@ class BotBase(GroupMixin): cog = cog._inject(self) self.__cogs[cog_name] = cog - def get_cog(self, name): + def get_cog(self, name: str) -> Optional[Cog]: """Gets the cog instance requested. If the cog is not found, ``None`` is returned instead. @@ -547,7 +572,7 @@ class BotBase(GroupMixin): """ return self.__cogs.get(name) - def remove_cog(self, name): + def remove_cog(self, name: str) -> Optional[Cog]: """Removes a cog from the bot and returns it. All registered commands and event listeners that the @@ -578,13 +603,13 @@ class BotBase(GroupMixin): return cog @property - def cogs(self): + def cogs(self) -> Mapping[str, Cog]: """Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" return types.MappingProxyType(self.__cogs) # extensions - def _remove_module_references(self, name): + def _remove_module_references(self, name: str) -> None: # find all references to the module # remove the cogs registered from the module for cogname, cog in self.__cogs.copy().items(): @@ -608,7 +633,7 @@ class BotBase(GroupMixin): for index in reversed(remove): del event_list[index] - def _call_module_finalizers(self, lib, key): + def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: func = getattr(lib, 'teardown') except AttributeError: @@ -626,12 +651,12 @@ class BotBase(GroupMixin): if _is_submodule(name, module): del sys.modules[module] - def _load_from_module_spec(self, spec, key): + def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: # precondition: key not in self.__extensions lib = importlib.util.module_from_spec(spec) sys.modules[key] = lib try: - spec.loader.exec_module(lib) + spec.loader.exec_module(lib) # type: ignore except Exception as e: del sys.modules[key] raise errors.ExtensionFailed(key, e) from e @@ -652,13 +677,13 @@ class BotBase(GroupMixin): else: self.__extensions[key] = lib - def _resolve_name(self, name, package): + def _resolve_name(self, name: str, package: Optional[str]) -> str: try: return importlib.util.resolve_name(name, package) except ImportError: raise errors.ExtensionNotFound(name) - def load_extension(self, name, *, package=None): + def load_extension(self, name: str, *, package: Optional[str] = None) -> None: """Loads an extension. An extension is a python module that contains commands, cogs, or @@ -705,7 +730,7 @@ class BotBase(GroupMixin): self._load_from_module_spec(spec, name) - def unload_extension(self, name, *, package=None): + def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: """Unloads an extension. When the extension is unloaded, all commands, listeners, and cogs are @@ -746,7 +771,7 @@ class BotBase(GroupMixin): self._remove_module_references(lib.__name__) self._call_module_finalizers(lib, name) - def reload_extension(self, name, *, package=None): + def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: """Atomically reloads an extension. This replaces the extension with the same extension, only refreshed. This is @@ -802,7 +827,7 @@ class BotBase(GroupMixin): # if the load failed, the remnants should have been # cleaned from the load_extension function call # so let's load it from our old compiled library. - lib.setup(self) + lib.setup(self) # type: ignore self.__extensions[name] = lib # revert sys.modules back to normal and raise back to caller @@ -810,18 +835,18 @@ class BotBase(GroupMixin): raise @property - def extensions(self): + def extensions(self) -> Mapping[str, types.ModuleType]: """Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" return types.MappingProxyType(self.__extensions) # help command stuff @property - def help_command(self): + def help_command(self) -> Optional[HelpCommand]: return self._help_command @help_command.setter - def help_command(self, value): + def help_command(self, value: Optional[HelpCommand]) -> None: if value is not None: if not isinstance(value, HelpCommand): raise TypeError('help_command must be a subclass of HelpCommand') @@ -837,7 +862,7 @@ class BotBase(GroupMixin): # command processing - async def get_prefix(self, message): + async def get_prefix(self, message: Message) -> Union[List[str], str]: """|coro| Retrieves the prefix the bot is listening to @@ -875,7 +900,7 @@ class BotBase(GroupMixin): return ret - async def get_context(self, message, *, cls=Context): + async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT: r"""|coro| Returns the invocation context from the message. @@ -908,7 +933,7 @@ class BotBase(GroupMixin): view = StringView(message.content) ctx = cls(prefix=None, view=view, bot=self, message=message) - if message.author.id == self.user.id: + if message.author.id == self.user.id: # type: ignore return ctx prefix = await self.get_prefix(message) @@ -945,11 +970,12 @@ class BotBase(GroupMixin): invoker = view.get_word() ctx.invoked_with = invoker - ctx.prefix = invoked_prefix + # type-checker fails to narrow invoked_prefix type. + ctx.prefix = invoked_prefix # type: ignore ctx.command = self.all_commands.get(invoker) return ctx - async def invoke(self, ctx): + async def invoke(self, ctx: Context) -> None: """|coro| Invokes the command given under the invocation context and @@ -975,7 +1001,7 @@ class BotBase(GroupMixin): exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') self.dispatch('command_error', ctx, exc) - async def process_commands(self, message): + async def process_commands(self, message: Message) -> None: """|coro| This function processes the commands that have been registered diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index da428cffd..9931557db 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations import inspect +import discord.utils + +from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type + from ._types import _BaseCommand +if TYPE_CHECKING: + from .bot import BotBase + from .context import Context + from .core import Command + __all__ = ( 'CogMeta', 'Cog', ) +CogT = TypeVar('CogT', bound='Cog') +FuncT = TypeVar('FuncT', bound=Callable[..., Any]) + +MISSING: Any = discord.utils.MISSING + class CogMeta(type): """A metaclass for defining a cog. @@ -89,8 +104,12 @@ class CogMeta(type): async def bar(self, ctx): pass # hidden -> False """ + __cog_name__: str + __cog_settings__: Dict[str, Any] + __cog_commands__: List[Command] + __cog_listeners__: List[Tuple[str, str]] - def __new__(cls, *args, **kwargs): + def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: name, bases, attrs = args attrs['__cog_name__'] = kwargs.pop('name', name) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) @@ -143,14 +162,14 @@ class CogMeta(type): new_cls.__cog_listeners__ = listeners_as_list return new_cls - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args) @classmethod - def qualified_name(cls): + def qualified_name(cls) -> str: return cls.__cog_name__ -def _cog_special_method(func): +def _cog_special_method(func: FuncT) -> FuncT: func.__cog_special_method__ = None return func @@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta): When inheriting from this class, the options shown in :class:`CogMeta` are equally valid here. """ + __cog_name__: ClassVar[str] + __cog_settings__: ClassVar[Dict[str, Any]] + __cog_commands__: ClassVar[List[Command]] + __cog_listeners__: ClassVar[List[Tuple[str, str]]] - def __new__(cls, *args, **kwargs): + def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT: # For issue 426, we need to store a copy of the command objects # since we modify them to inject `self` to them. # To do this, we need to interfere with the Cog creation process. @@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta): cmd_attrs = cls.__cog_settings__ # Either update the command with the cog provided defaults or copy it. - self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) + # r.e type ignore, type-checker complains about overriding a ClassVar + self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore lookup = { cmd.qualified_name: cmd @@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta): parent = command.parent if parent is not None: # Get the latest parent reference - parent = lookup[parent.qualified_name] + parent = lookup[parent.qualified_name] # type: ignore # Update our parent's reference to our self - parent.remove_command(command.name) - parent.add_command(command) + parent.remove_command(command.name) # type: ignore + parent.add_command(command) # type: ignore return self - def get_commands(self): + def get_commands(self) -> List[Command]: r""" Returns -------- @@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta): return [c for c in self.__cog_commands__ if c.parent is None] @property - def qualified_name(self): + def qualified_name(self) -> str: """:class:`str`: Returns the cog's specified name, not the class name.""" return self.__cog_name__ @property - def description(self): + def description(self) -> str: """:class:`str`: Returns the cog's description, typically the cleaned docstring.""" return self.__cog_description__ @description.setter - def description(self, description): + def description(self, description: str) -> None: self.__cog_description__ = description - def walk_commands(self): + def walk_commands(self) -> Generator[Command, None, None]: """An iterator that recursively walks through this cog's commands and subcommands. Yields @@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta): if isinstance(command, GroupMixin): yield from command.walk_commands() - def get_listeners(self): + def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]: """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog. Returns @@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta): return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] @classmethod - def _get_overridden_method(cls, method): + def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: """Return None if the method is not overridden. Otherwise returns the overridden method.""" return getattr(method.__func__, '__cog_special_method__', method) @classmethod - def listener(cls, name=None): + def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: """A decorator that marks a function as a listener. This is the cog equivalent of :meth:`.Bot.listen`. @@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta): the name. """ - if name is not None and not isinstance(name, str): + if name is not MISSING and not isinstance(name, str): raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') - def decorator(func): + def decorator(func: FuncT) -> FuncT: actual = func if isinstance(actual, staticmethod): actual = actual.__func__ @@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta): return func return decorator - def has_error_handler(self): + def has_error_handler(self) -> bool: """:class:`bool`: Checks whether the cog has an error handler. .. versionadded:: 1.7 @@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta): return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') @_cog_special_method - def cog_unload(self): + def cog_unload(self) -> None: """A special method that is called when the cog gets removed. This function **cannot** be a coroutine. It must be a regular @@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta): pass @_cog_special_method - def bot_check_once(self, ctx): + def bot_check_once(self, ctx: Context) -> bool: """A special method that registers as a :meth:`.Bot.check_once` check. @@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta): return True @_cog_special_method - def bot_check(self, ctx): + def bot_check(self, ctx: Context) -> bool: """A special method that registers as a :meth:`.Bot.check` check. @@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta): return True @_cog_special_method - def cog_check(self, ctx): + def cog_check(self, ctx: Context) -> bool: """A special method that registers as a :func:`~discord.ext.commands.check` for every command and subcommand in this cog. @@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta): return True @_cog_special_method - async def cog_command_error(self, ctx, error): + async def cog_command_error(self, ctx: Context, error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. @@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta): pass @_cog_special_method - async def cog_before_invoke(self, ctx): + async def cog_before_invoke(self, ctx: Context) -> None: """A special method that acts as a cog local pre-invoke hook. This is similar to :meth:`.Command.before_invoke`. @@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta): pass @_cog_special_method - async def cog_after_invoke(self, ctx): + async def cog_after_invoke(self, ctx: Context) -> None: """A special method that acts as a cog local post-invoke hook. This is similar to :meth:`.Command.after_invoke`. @@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta): """ pass - def _inject(self, bot): + def _inject(self: CogT, bot: BotBase) -> CogT: cls = self.__class__ # realistically, the only thing that can cause loading errors @@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta): return self - def _eject(self, bot): + def _eject(self, bot: BotBase) -> None: cls = self.__class__ try: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index c5367c24f..e231f0e70 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +import inspect +import re + +from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union import discord.abc import discord.utils -import re + +from discord.message import Message + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from discord.abc import MessageableChannel + from discord.guild import Guild + from discord.member import Member + from discord.state import ConnectionState + from discord.user import ClientUser, User + from discord.voice_client import VoiceProtocol + + from .bot import Bot, AutoShardedBot + from .cog import Cog + from .core import Command + from .help import HelpCommand + from .view import StringView __all__ = ( 'Context', ) -class Context(discord.abc.Messageable): +MISSING: Any = discord.utils.MISSING + + +T = TypeVar('T') +BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") +CogT = TypeVar('CogT', bound="Cog") + +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') + + +class Context(discord.abc.Messageable, Generic[BotT]): r"""Represents the context in which a command is being invoked under. This class contains a lot of meta data to help you understand more about @@ -58,11 +94,11 @@ class Context(discord.abc.Messageable): This is only of use for within converters. .. versionadded:: 2.0 - prefix: :class:`str` + prefix: Optional[:class:`str`] The prefix that was used to invoke the command. - command: :class:`Command` + command: Optional[:class:`Command`] The command that is being invoked currently. - invoked_with: :class:`str` + invoked_with: Optional[:class:`str`] The command name that triggered this invocation. Useful for finding out which alias called the command. invoked_parents: List[:class:`str`] @@ -73,7 +109,7 @@ class Context(discord.abc.Messageable): .. versionadded:: 1.7 - invoked_subcommand: :class:`Command` + invoked_subcommand: Optional[:class:`Command`] The subcommand that was invoked. If no valid subcommand was invoked then this is equal to ``None``. subcommand_passed: Optional[:class:`str`] @@ -86,23 +122,38 @@ class Context(discord.abc.Messageable): or invoked. """ - def __init__(self, **attrs): - self.message = attrs.pop('message', None) - self.bot = attrs.pop('bot', None) - self.args = attrs.pop('args', []) - self.kwargs = attrs.pop('kwargs', {}) - self.prefix = attrs.pop('prefix') - self.command = attrs.pop('command', None) - self.view = attrs.pop('view', None) - self.invoked_with = attrs.pop('invoked_with', None) - self.invoked_parents = attrs.pop('invoked_parents', []) - self.invoked_subcommand = attrs.pop('invoked_subcommand', None) - self.subcommand_passed = attrs.pop('subcommand_passed', None) - self.command_failed = attrs.pop('command_failed', False) - self.current_parameter = attrs.pop('current_parameter', None) - self._state = self.message._state - - async def invoke(self, command, /, *args, **kwargs): + def __init__(self, + *, + message: Message, + bot: BotT, + view: StringView, + args: List[Any] = MISSING, + kwargs: Dict[str, Any] = MISSING, + prefix: Optional[str] = None, + command: Optional[Command] = None, + invoked_with: Optional[str] = None, + invoked_parents: List[str] = MISSING, + invoked_subcommand: Optional[Command] = None, + subcommand_passed: Optional[str] = None, + command_failed: bool = False, + current_parameter: Optional[inspect.Parameter] = None, + ): + self.message: Message = message + self.bot: BotT = bot + self.args: List[Any] = args or [] + self.kwargs: Dict[str, Any] = kwargs or {} + self.prefix: Optional[str] = prefix + self.command: Optional[Command] = command + self.view: StringView = view + self.invoked_with: Optional[str] = invoked_with + self.invoked_parents: List[str] = invoked_parents or [] + self.invoked_subcommand: Optional[Command] = invoked_subcommand + self.subcommand_passed: Optional[str] = subcommand_passed + self.command_failed: bool = command_failed + self.current_parameter: Optional[inspect.Parameter] = current_parameter + self._state: ConnectionState = self.message._state + + async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| Calls a command with the arguments given. @@ -133,17 +184,9 @@ class Context(discord.abc.Messageable): TypeError The command argument to invoke is missing. """ - arguments = [] - if command.cog is not None: - arguments.append(command.cog) - - arguments.append(self) - arguments.extend(args) + return await command(self, *args, **kwargs) - ret = await command.callback(*arguments, **kwargs) - return ret - - async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True): + async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: """|coro| Calls the command again. @@ -187,7 +230,7 @@ class Context(discord.abc.Messageable): if restart: to_call = cmd.root_parent or cmd - view.index = len(self.prefix) + view.index = len(self.prefix or '') view.previous = 0 self.invoked_parents = [] self.invoked_with = view.get_word() # advance to get the root command @@ -206,20 +249,23 @@ class Context(discord.abc.Messageable): self.subcommand_passed = subcommand_passed @property - def valid(self): + def valid(self) -> bool: """:class:`bool`: Checks if the invocation context is valid to be invoked with.""" return self.prefix is not None and self.command is not None - async def _get_channel(self): + async def _get_channel(self) -> discord.abc.Messageable: return self.channel @property - def clean_prefix(self): + def clean_prefix(self) -> str: """:class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``. .. versionadded:: 2.0 """ - user = self.guild.me if self.guild else self.bot.user + if self.prefix is None: + return '' + + user = self.me # this breaks if the prefix mention is not the bot itself but I # consider this to be an *incredibly* strange use case. I'd rather go # for this common use case rather than waste performance for the @@ -228,7 +274,7 @@ class Context(discord.abc.Messageable): return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix) @property - def cog(self): + def cog(self) -> Optional[Cog]: """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist.""" if self.command is None: @@ -236,38 +282,39 @@ class Context(discord.abc.Messageable): return self.command.cog @discord.utils.cached_property - def guild(self): + def guild(self) -> Optional[Guild]: """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available.""" return self.message.guild @discord.utils.cached_property - def channel(self): + def channel(self) -> MessageableChannel: """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. Shorthand for :attr:`.Message.channel`. """ return self.message.channel @discord.utils.cached_property - def author(self): + def author(self) -> Union[User, Member]: """Union[:class:`~discord.User`, :class:`.Member`]: Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` """ return self.message.author @discord.utils.cached_property - def me(self): + def me(self) -> Union[Member, ClientUser]: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts. """ - return self.guild.me if self.guild is not None else self.bot.user + # bot.user will never be None at this point. + return self.guild.me if self.guild is not None else self.bot.user # type: ignore @property - def voice_client(self): + def voice_client(self) -> Optional[VoiceProtocol]: r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" g = self.guild return g.voice_client if g else None - async def send_help(self, *args): + async def send_help(self, *args: Any) -> Any: """send_help(entity=) |coro| @@ -319,12 +366,12 @@ class Context(discord.abc.Messageable): return None entity = args[0] - if entity is None: - return None - if isinstance(entity, str): entity = bot.get_cog(entity) or bot.get_command(entity) + if entity is None: + return None + try: entity.qualified_name except AttributeError: @@ -348,6 +395,6 @@ class Context(discord.abc.Messageable): except CommandError as e: await cmd.on_help_command_error(self, e) - @discord.utils.copy_doc(discord.Message.reply) - async def reply(self, content=None, **kwargs): + @discord.utils.copy_doc(Message.reply) + async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message: return await self.message.reply(content, **kwargs) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index abf883260..88e65507d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -21,19 +21,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations from typing import ( Any, Callable, Dict, + Generator, + Generic, Literal, + List, + Optional, Union, + Set, + Tuple, + TypeVar, + Type, + TYPE_CHECKING, + overload, ) import asyncio import functools import inspect import datetime -import types import discord @@ -42,6 +52,22 @@ from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, Dy from .converter import run_converters, get_converter, Greedy from ._types import _BaseCommand from .cog import Cog +from .context import Context + + +if TYPE_CHECKING: + from typing_extensions import Concatenate, ParamSpec, TypeGuard + + from discord.message import Message + + from ._types import ( + Coro, + CoroFunc, + Check, + Hook, + Error, + ) + __all__ = ( 'Command', @@ -70,6 +96,22 @@ __all__ = ( 'bot_has_guild_permissions' ) +MISSING: Any = discord.utils.MISSING + +T = TypeVar('T') +CogT = TypeVar('CogT', bound='Cog') +CommandT = TypeVar('CommandT', bound='Command') +ContextT = TypeVar('ContextT', bound='Context') +# CHT = TypeVar('CHT', bound='Check') +GroupT = TypeVar('GroupT', bound='Group') +HookT = TypeVar('HookT', bound='Hook') +ErrorT = TypeVar('ErrorT', bound='Error') + +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: partial = functools.partial while True: @@ -160,7 +202,7 @@ class _CaseInsensitiveDict(dict): def __setitem__(self, k, v): super().__setitem__(k.casefold(), v) -class Command(_BaseCommand): +class Command(_BaseCommand, Generic[CogT, P, T]): r"""A class that implements the protocol for a bot text command. These are not created manually, instead they are created via the @@ -172,7 +214,7 @@ class Command(_BaseCommand): The name of the command. callback: :ref:`coroutine ` The coroutine that is executed when the command is called. - help: :class:`str` + help: Optional[:class:`str`] The long help text for the command. brief: Optional[:class:`str`] The short help text for the command. @@ -235,8 +277,9 @@ class Command(_BaseCommand): .. versionadded:: 2.0 """ + __original_kwargs__: Dict[str, Any] - def __new__(cls, *args, **kwargs): + def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT: # if you're wondering why this is done, it's because we need to ensure # we have a complete original copy of **kwargs even for classes that # mess with it by popping before delegating to the subclass __init__. @@ -252,16 +295,20 @@ class Command(_BaseCommand): self.__original_kwargs__ = kwargs.copy() return self - def __init__(self, func, **kwargs): + def __init__(self, func: Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ], **kwargs: Any): if not asyncio.iscoroutinefunction(func): raise TypeError('Callback must be a coroutine.') - self.name = name = kwargs.get('name') or func.__name__ + name = kwargs.get('name') or func.__name__ if not isinstance(name, str): raise TypeError('Name of a command must be a string.') + self.name: str = name self.callback = func - self.enabled = kwargs.get('enabled', True) + self.enabled: bool = kwargs.get('enabled', True) help_doc = kwargs.get('help') if help_doc is not None: @@ -271,74 +318,85 @@ class Command(_BaseCommand): if isinstance(help_doc, bytes): help_doc = help_doc.decode('utf-8') - self.help = help_doc + self.help: Optional[str] = help_doc - self.brief = kwargs.get('brief') - self.usage = kwargs.get('usage') - self.rest_is_raw = kwargs.get('rest_is_raw', False) - self.aliases = kwargs.get('aliases', []) - self.extras = kwargs.get('extras', {}) + self.brief: Optional[str] = kwargs.get('brief') + self.usage: Optional[str] = kwargs.get('usage') + self.rest_is_raw: bool = kwargs.get('rest_is_raw', False) + self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', []) + self.extras: Dict[str, Any] = kwargs.get('extras', {}) if not isinstance(self.aliases, (list, tuple)): raise TypeError("Aliases of a command must be a list or a tuple of strings.") - self.description = inspect.cleandoc(kwargs.get('description', '')) - self.hidden = kwargs.get('hidden', False) + self.description: str = inspect.cleandoc(kwargs.get('description', '')) + self.hidden: bool = kwargs.get('hidden', False) try: checks = func.__commands_checks__ checks.reverse() except AttributeError: checks = kwargs.get('checks', []) - finally: - self.checks = checks + + self.checks: List[Check] = checks try: cooldown = func.__commands_cooldown__ except AttributeError: cooldown = kwargs.get('cooldown') - finally: - if cooldown is None: - self._buckets = CooldownMapping(cooldown, BucketType.default) - elif isinstance(cooldown, CooldownMapping): - self._buckets = cooldown + + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets = cooldown + else: + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + self._buckets: CooldownMapping = buckets try: max_concurrency = func.__commands_max_concurrency__ except AttributeError: max_concurrency = kwargs.get('max_concurrency') - finally: - self._max_concurrency = max_concurrency - self.require_var_positional = kwargs.get('require_var_positional', False) - self.ignore_extra = kwargs.get('ignore_extra', True) - self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False) - self.cog = None + self._max_concurrency: Optional[MaxConcurrency] = max_concurrency + + self.require_var_positional: bool = kwargs.get('require_var_positional', False) + self.ignore_extra: bool = kwargs.get('ignore_extra', True) + self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False) + self.cog: Optional[CogT] = None # bandaid for the fact that sometimes parent can be the bot instance parent = kwargs.get('parent') - self.parent = parent if isinstance(parent, _BaseCommand) else None + self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore + self._before_invoke: Optional[Hook] = None try: before_invoke = func.__before_invoke__ except AttributeError: - self._before_invoke = None + pass else: self.before_invoke(before_invoke) + self._after_invoke: Optional[Hook] = None try: after_invoke = func.__after_invoke__ except AttributeError: - self._after_invoke = None + pass else: self.after_invoke(after_invoke) @property - def callback(self): + def callback(self) -> Union[ + Callable[Concatenate[CogT, Context, P], Coro[T]], + Callable[Concatenate[Context, P], Coro[T]], + ]: return self._callback @callback.setter - def callback(self, function): + def callback(self, function: Union[ + Callable[Concatenate[CogT, Context, P], Coro[T]], + Callable[Concatenate[Context, P], Coro[T]], + ]) -> None: self._callback = function unwrap = unwrap_function(function) self.module = unwrap.__module__ @@ -350,7 +408,7 @@ class Command(_BaseCommand): self.params = get_signature_parameters(function, globalns) - def add_check(self, func): + def add_check(self, func: Check) -> None: """Adds a check to the command. This is the non-decorator interface to :func:`.check`. @@ -365,7 +423,7 @@ class Command(_BaseCommand): self.checks.append(func) - def remove_check(self, func): + def remove_check(self, func: Check) -> None: """Removes a check from the command. This function is idempotent and will not raise an exception @@ -384,8 +442,8 @@ class Command(_BaseCommand): except ValueError: pass - def update(self, **kwargs): - """Updates :class:`Command` instance with updated attributes. + def update(self, **kwargs: Any) -> None: + """Updates :class:`Command` instance with updated attribute. This works similarly to the :func:`.command` decorator in terms of parameters in that they are passed to the :class:`Command` or @@ -393,7 +451,7 @@ class Command(_BaseCommand): """ self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) - async def __call__(self, *args, **kwargs): + async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T: """|coro| Calls the internal callback that the command holds. @@ -407,11 +465,11 @@ class Command(_BaseCommand): .. versionadded:: 1.3 """ if self.cog is not None: - return await self.callback(self.cog, *args, **kwargs) + return await self.callback(self.cog, context, *args, **kwargs) # type: ignore else: - return await self.callback(*args, **kwargs) + return await self.callback(context, *args, **kwargs) # type: ignore - def _ensure_assignment_on_copy(self, other): + def _ensure_assignment_on_copy(self, other: CommandT) -> CommandT: other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke if self.checks != other.checks: @@ -419,7 +477,8 @@ class Command(_BaseCommand): if self._buckets.valid and not other._buckets.valid: other._buckets = self._buckets.copy() if self._max_concurrency != other._max_concurrency: - other._max_concurrency = self._max_concurrency.copy() + # _max_concurrency won't be None at this point + other._max_concurrency = self._max_concurrency.copy() # type: ignore try: other.on_error = self.on_error @@ -427,7 +486,7 @@ class Command(_BaseCommand): pass return other - def copy(self): + def copy(self: CommandT) -> CommandT: """Creates a copy of this command. Returns @@ -438,7 +497,7 @@ class Command(_BaseCommand): ret = self.__class__(self.callback, **self.__original_kwargs__) return self._ensure_assignment_on_copy(ret) - def _update_copy(self, kwargs): + def _update_copy(self: CommandT, kwargs: Dict[str, Any]) -> CommandT: if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) @@ -447,7 +506,7 @@ class Command(_BaseCommand): else: return self.copy() - async def dispatch_error(self, ctx, error): + async def dispatch_error(self, ctx: Context, error: Exception) -> None: ctx.command_failed = True cog = self.cog try: @@ -470,7 +529,7 @@ class Command(_BaseCommand): finally: ctx.bot.dispatch('command_error', ctx, error) - async def transform(self, ctx, param): + async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: required = param.default is param.empty converter = get_converter(param) consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw @@ -508,9 +567,10 @@ class Command(_BaseCommand): argument = view.get_quoted_word() view.previous = previous - return await run_converters(ctx, converter, argument, param) + # type-checker fails to narrow argument + return await run_converters(ctx, converter, argument, param) # type: ignore - async def _transform_greedy_pos(self, ctx, param, required, converter): + async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any: view = ctx.view result = [] while not view.eof: @@ -520,7 +580,7 @@ class Command(_BaseCommand): view.skip_ws() try: argument = view.get_quoted_word() - value = await run_converters(ctx, converter, argument, param) + value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous break @@ -531,12 +591,12 @@ class Command(_BaseCommand): return param.default return result - async def _transform_greedy_var_pos(self, ctx, param, converter): + async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: view = ctx.view previous = view.index try: argument = view.get_quoted_word() - value = await run_converters(ctx, converter, argument, param) + value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous raise RuntimeError() from None # break loop @@ -567,7 +627,7 @@ class Command(_BaseCommand): return result @property - def full_parent_name(self): + def full_parent_name(self) -> str: """:class:`str`: Retrieves the fully qualified parent command name. This the base command name required to execute it. For example, @@ -575,14 +635,15 @@ class Command(_BaseCommand): """ entries = [] command = self - while command.parent is not None: - command = command.parent - entries.append(command.name) + # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command.name) # type: ignore return ' '.join(reversed(entries)) @property - def parents(self): + def parents(self) -> List[Group]: """List[:class:`Group`]: Retrieves the parents of this command. If the command has no parents then it returns an empty :class:`list`. @@ -593,14 +654,14 @@ class Command(_BaseCommand): """ entries = [] command = self - while command.parent is not None: - command = command.parent + while command.parent is not None: # type: ignore + command = command.parent # type: ignore entries.append(command) return entries @property - def root_parent(self): + def root_parent(self) -> Optional[Group]: """Optional[:class:`Group`]: Retrieves the root parent of this command. If the command has no parents then it returns ``None``. @@ -612,7 +673,7 @@ class Command(_BaseCommand): return self.parents[-1] @property - def qualified_name(self): + def qualified_name(self) -> str: """:class:`str`: Retrieves the fully qualified command name. This is the full parent name with the command name as well. @@ -626,10 +687,10 @@ class Command(_BaseCommand): else: return self.name - def __str__(self): + def __str__(self) -> str: return self.qualified_name - async def _parse_arguments(self, ctx): + async def _parse_arguments(self, ctx: Context) -> None: ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = {} args = ctx.args @@ -679,7 +740,7 @@ class Command(_BaseCommand): if not self.ignore_extra and not view.eof: raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) - async def call_before_hooks(self, ctx): + async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog @@ -689,9 +750,9 @@ class Command(_BaseCommand): # __self__ only exists for methods, not functions # however, if @command.before_invoke is used, it will be a function if instance: - await self._before_invoke(instance, ctx) + await self._before_invoke(instance, ctx) # type: ignore else: - await self._before_invoke(ctx) + await self._before_invoke(ctx) # type: ignore # call the cog local hook if applicable: if cog is not None: @@ -704,14 +765,14 @@ class Command(_BaseCommand): if hook is not None: await hook(ctx) - async def call_after_hooks(self, ctx): + async def call_after_hooks(self, ctx: Context) -> None: cog = self.cog if self._after_invoke is not None: instance = getattr(self._after_invoke, '__self__', cog) if instance: - await self._after_invoke(instance, ctx) + await self._after_invoke(instance, ctx) # type: ignore else: - await self._after_invoke(ctx) + await self._after_invoke(ctx) # type: ignore # call the cog local hook if applicable: if cog is not None: @@ -723,7 +784,7 @@ class Command(_BaseCommand): if hook is not None: await hook(ctx) - def _prepare_cooldowns(self, ctx): + def _prepare_cooldowns(self, ctx: Context) -> None: if self._buckets.valid: dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() @@ -731,16 +792,17 @@ class Command(_BaseCommand): if bucket is not None: retry_after = bucket.update_rate_limit(current) if retry_after: - raise CommandOnCooldown(bucket, retry_after, self._buckets.type) + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore - async def prepare(self, ctx): + async def prepare(self, ctx: Context) -> None: ctx.command = self if not await self.can_run(ctx): raise CheckFailure(f'The check functions for command {self.qualified_name} failed.') if self._max_concurrency is not None: - await self._max_concurrency.acquire(ctx) + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) # type: ignore try: if self.cooldown_after_parsing: @@ -753,10 +815,10 @@ class Command(_BaseCommand): await self.call_before_hooks(ctx) except: if self._max_concurrency is not None: - await self._max_concurrency.release(ctx) + await self._max_concurrency.release(ctx) # type: ignore raise - def is_on_cooldown(self, ctx): + def is_on_cooldown(self, ctx: Context) -> bool: """Checks whether the command is currently on cooldown. Parameters @@ -777,7 +839,7 @@ class Command(_BaseCommand): current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() return bucket.get_tokens(current) == 0 - def reset_cooldown(self, ctx): + def reset_cooldown(self, ctx: Context) -> None: """Resets the cooldown on this command. Parameters @@ -789,7 +851,7 @@ class Command(_BaseCommand): bucket = self._buckets.get_bucket(ctx.message) bucket.reset() - def get_cooldown_retry_after(self, ctx): + def get_cooldown_retry_after(self, ctx: Context) -> float: """Retrieves the amount of seconds before this command can be tried again. .. versionadded:: 1.4 @@ -813,7 +875,7 @@ class Command(_BaseCommand): return 0.0 - async def invoke(self, ctx): + async def invoke(self, ctx: Context) -> None: await self.prepare(ctx) # terminate the invoked_subcommand chain. @@ -824,7 +886,7 @@ class Command(_BaseCommand): injected = hooked_wrapped_callback(self, ctx, self.callback) await injected(*ctx.args, **ctx.kwargs) - async def reinvoke(self, ctx, *, call_hooks=False): + async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: ctx.command = self await self._parse_arguments(ctx) @@ -833,7 +895,7 @@ class Command(_BaseCommand): ctx.invoked_subcommand = None try: - await self.callback(*ctx.args, **ctx.kwargs) + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore except: ctx.command_failed = True raise @@ -841,7 +903,7 @@ class Command(_BaseCommand): if call_hooks: await self.call_after_hooks(ctx) - def error(self, coro): + def error(self, coro: ErrorT) -> ErrorT: """A decorator that registers a coroutine as a local error handler. A local error handler is an :func:`.on_command_error` event limited to @@ -862,17 +924,17 @@ class Command(_BaseCommand): if not asyncio.iscoroutinefunction(coro): raise TypeError('The error handler must be a coroutine.') - self.on_error = coro + self.on_error: Error = coro return coro - def has_error_handler(self): + def has_error_handler(self) -> bool: """:class:`bool`: Checks whether the command has an error handler registered. .. versionadded:: 1.7 """ return hasattr(self, 'on_error') - def before_invoke(self, coro): + def before_invoke(self, coro: HookT) -> HookT: """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is @@ -899,7 +961,7 @@ class Command(_BaseCommand): self._before_invoke = coro return coro - def after_invoke(self, coro): + def after_invoke(self, coro: HookT) -> HookT: """A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is @@ -927,12 +989,12 @@ class Command(_BaseCommand): return coro @property - def cog_name(self): + def cog_name(self) -> Optional[str]: """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" return type(self.cog).__cog_name__ if self.cog is not None else None @property - def short_doc(self): + def short_doc(self) -> str: """:class:`str`: Gets the "short" documentation of a command. By default, this is the :attr:`.brief` attribute. @@ -945,11 +1007,11 @@ class Command(_BaseCommand): return self.help.split('\n', 1)[0] return '' - def _is_typing_optional(self, annotation): - return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ + def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: + return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore @property - def signature(self): + def signature(self) -> str: """:class:`str`: Returns a POSIX-like signature useful for help command output.""" if self.usage is not None: return self.usage @@ -1002,7 +1064,7 @@ class Command(_BaseCommand): return ' '.join(result) - async def can_run(self, ctx): + async def can_run(self, ctx: Context) -> bool: """|coro| Checks if the command can be executed by checking all the predicates @@ -1052,7 +1114,7 @@ class Command(_BaseCommand): # since we have no checks, then we just return True. return True - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) + return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore finally: ctx.command = original @@ -1068,24 +1130,24 @@ class GroupMixin: case_insensitive: :class:`bool` Whether the commands should be case insensitive. Defaults to ``False``. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get('case_insensitive', False) - self.all_commands = _CaseInsensitiveDict() if case_insensitive else {} - self.case_insensitive = case_insensitive + self.all_commands: Dict[str, Command] = _CaseInsensitiveDict() if case_insensitive else {} + self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @property - def commands(self): + def commands(self) -> Set[Command]: """Set[:class:`.Command`]: A unique set of commands without aliases that are registered.""" return set(self.all_commands.values()) - def recursively_remove_all_commands(self): + def recursively_remove_all_commands(self) -> None: for command in self.all_commands.copy().values(): if isinstance(command, GroupMixin): command.recursively_remove_all_commands() self.remove_command(command.name) - def add_command(self, command): + def add_command(self, command: Command) -> None: """Adds a :class:`.Command` into the internal list of commands. This is usually not called, instead the :meth:`~.GroupMixin.command` or @@ -1123,7 +1185,7 @@ class GroupMixin: raise CommandRegistrationError(alias, alias_conflict=True) self.all_commands[alias] = command - def remove_command(self, name): + def remove_command(self, name: str) -> Optional[Command]: """Remove a :class:`.Command` from the internal list of commands. @@ -1156,11 +1218,11 @@ class GroupMixin: # in the case of a CommandRegistrationError, an alias might conflict # with an already existing command. If this is the case, we want to # make sure the pre-existing command is not removed. - if cmd not in (None, command): + if cmd is not None and cmd != command: self.all_commands[alias] = cmd return command - def walk_commands(self): + def walk_commands(self) -> Generator[Command, None, None]: """An iterator that recursively walks through all commands and subcommands. .. versionchanged:: 1.4 @@ -1176,7 +1238,7 @@ class GroupMixin: if isinstance(command, GroupMixin): yield from command.walk_commands() - def get_command(self, name): + def get_command(self, name: str) -> Optional[Command]: """Get a :class:`.Command` from the internal list of commands. @@ -1210,13 +1272,39 @@ class GroupMixin: for name in names[1:]: try: - obj = obj.all_commands[name] + obj = obj.all_commands[name] # type: ignore except (AttributeError, KeyError): return None return obj - def command(self, *args, **kwargs): + @overload + def command( + self, + name: str = ..., + cls: Type[Command[CogT, P, T]] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[T]]], Command[CogT, P, T]]: + ... + + @overload + def command( + self, + name: str = ..., + cls: Type[CommandT] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: + ... + + def command( + self, + name: str = MISSING, + cls: Type[CommandT] = MISSING, + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: """A shortcut decorator that invokes :func:`.command` and adds it to the internal command list via :meth:`~.GroupMixin.add_command`. @@ -1225,15 +1313,41 @@ class GroupMixin: Callable[..., :class:`Command`] A decorator that converts the provided method into a Command, adds it to the bot, then returns it. """ - def decorator(func): + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: kwargs.setdefault('parent', self) - result = command(*args, **kwargs)(func) + result = command(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result return decorator - def group(self, *args, **kwargs): + @overload + def group( + self, + name: str = ..., + cls: Type[Group[CogT, P, T]] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[T]]], Group[CogT, P, T]]: + ... + + @overload + def group( + self, + name: str = ..., + cls: Type[GroupT] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: + ... + + def group( + self, + name: str = MISSING, + cls: Type[GroupT] = MISSING, + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: """A shortcut decorator that invokes :func:`.group` and adds it to the internal command list via :meth:`~.GroupMixin.add_command`. @@ -1242,15 +1356,15 @@ class GroupMixin: Callable[..., :class:`Group`] A decorator that converts the provided method into a Group, adds it to the bot, then returns it. """ - def decorator(func): + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: kwargs.setdefault('parent', self) - result = group(*args, **kwargs)(func) + result = group(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result return decorator -class Group(GroupMixin, Command): +class Group(GroupMixin, Command[CogT, P, T]): """A class that implements a grouping protocol for commands to be executed as subcommands. @@ -1272,11 +1386,11 @@ class Group(GroupMixin, Command): Indicates if the group's commands should be case insensitive. Defaults to ``False``. """ - def __init__(self, *args, **attrs): - self.invoke_without_command = attrs.pop('invoke_without_command', False) + def __init__(self, *args: Any, **attrs: Any) -> None: + self.invoke_without_command: bool = attrs.pop('invoke_without_command', False) super().__init__(*args, **attrs) - def copy(self): + def copy(self: GroupT) -> GroupT: """Creates a copy of this :class:`Group`. Returns @@ -1287,9 +1401,9 @@ class Group(GroupMixin, Command): ret = super().copy() for cmd in self.commands: ret.add_command(cmd.copy()) - return ret + return ret # type: ignore - async def invoke(self, ctx): + async def invoke(self, ctx: Context) -> None: ctx.invoked_subcommand = None ctx.subcommand_passed = None early_invoke = not self.invoke_without_command @@ -1309,7 +1423,7 @@ class Group(GroupMixin, Command): injected = hooked_wrapped_callback(self, ctx, self.callback) await injected(*ctx.args, **ctx.kwargs) - ctx.invoked_parents.append(ctx.invoked_with) + ctx.invoked_parents.append(ctx.invoked_with) # type: ignore if trigger and ctx.invoked_subcommand: ctx.invoked_with = trigger @@ -1320,7 +1434,7 @@ class Group(GroupMixin, Command): view.previous = previous await super().invoke(ctx) - async def reinvoke(self, ctx, *, call_hooks=False): + async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: ctx.invoked_subcommand = None early_invoke = not self.invoke_without_command if early_invoke: @@ -1341,7 +1455,7 @@ class Group(GroupMixin, Command): if early_invoke: try: - await self.callback(*ctx.args, **ctx.kwargs) + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore except: ctx.command_failed = True raise @@ -1349,7 +1463,7 @@ class Group(GroupMixin, Command): if call_hooks: await self.call_after_hooks(ctx) - ctx.invoked_parents.append(ctx.invoked_with) + ctx.invoked_parents.append(ctx.invoked_with) # type: ignore if trigger and ctx.invoked_subcommand: ctx.invoked_with = trigger @@ -1362,7 +1476,48 @@ class Group(GroupMixin, Command): # Decorators -def command(name=None, cls=None, **attrs): +@overload +def command( + name: str = ..., + cls: Type[Command[CogT, P, T]] = ..., + **attrs: Any, +) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ] +, Command[CogT, P, T]]: + ... + +@overload +def command( + name: str = ..., + cls: Type[CommandT] = ..., + **attrs: Any, +) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], + Callable[Concatenate[ContextT, P], Coro[Any]], + ] + ] +, CommandT]: + ... + +def command( + name: str = MISSING, + cls: Type[CommandT] = MISSING, + **attrs: Any +) -> Callable[ + [ + Union[ + Callable[Concatenate[ContextT, P], Coro[Any]], + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + ] + ] +, Union[Command[CogT, P, T], CommandT]]: """A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. @@ -1392,17 +1547,61 @@ def command(name=None, cls=None, **attrs): TypeError If the function is not a coroutine or is already a command. """ - if cls is None: - cls = Command + if cls is MISSING: + cls = Command # type: ignore - def decorator(func): + def decorator(func: Union[ + Callable[Concatenate[ContextT, P], Coro[Any]], + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], + ]) -> CommandT: if isinstance(func, Command): raise TypeError('Callback is already a command.') return cls(func, name=name, **attrs) return decorator -def group(name=None, **attrs): +@overload +def group( + name: str = ..., + cls: Type[Group[CogT, P, T]] = ..., + **attrs: Any, +) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ] +, Group[CogT, P, T]]: + ... + +@overload +def group( + name: str = ..., + cls: Type[GroupT] = ..., + **attrs: Any, +) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], + Callable[Concatenate[ContextT, P], Coro[Any]], + ] + ] +, GroupT]: + ... + +def group( + name: str = MISSING, + cls: Type[GroupT] = MISSING, + **attrs: Any, +) -> Callable[ + [ + Union[ + Callable[Concatenate[ContextT, P], Coro[Any]], + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + ] + ] +, Union[Group[CogT, P, T], GroupT]]: """A decorator that transforms a function into a :class:`.Group`. This is similar to the :func:`.command` decorator but the ``cls`` @@ -1411,11 +1610,11 @@ def group(name=None, **attrs): .. versionchanged:: 1.1 The ``cls`` parameter can now be passed. """ + if cls is MISSING: + cls = Group # type: ignore + return command(name=name, cls=cls, **attrs) # type: ignore - attrs.setdefault('cls', Group) - return command(name=name, **attrs) - -def check(predicate): +def check(predicate: Check) -> Callable[[T], T]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1486,7 +1685,7 @@ def check(predicate): The predicate to check if the command should be invoked. """ - def decorator(func): + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.checks.append(predicate) else: @@ -1502,12 +1701,12 @@ def check(predicate): else: @functools.wraps(predicate) async def wrapper(ctx): - return predicate(ctx) + return predicate(ctx) # type: ignore decorator.predicate = wrapper - return decorator + return decorator # type: ignore -def check_any(*checks): +def check_any(*checks: Check) -> Callable[[T], T]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1560,7 +1759,7 @@ def check_any(*checks): else: unwrapped.append(pred) - async def predicate(ctx): + async def predicate(ctx: Context) -> bool: errors = [] for func in unwrapped: try: @@ -1575,7 +1774,7 @@ def check_any(*checks): return check(predicate) -def has_role(item): +def has_role(item: Union[int, str]) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -1602,21 +1801,22 @@ def has_role(item): The name or ID of the role to check. """ - def predicate(ctx): + def predicate(ctx: Context) -> bool: if ctx.guild is None: raise NoPrivateMessage() + # ctx.guild is None doesn't narrow ctx.author to Member if isinstance(item, int): - role = discord.utils.get(ctx.author.roles, id=item) + role = discord.utils.get(ctx.author.roles, id=item) # type: ignore else: - role = discord.utils.get(ctx.author.roles, name=item) + role = discord.utils.get(ctx.author.roles, name=item) # type: ignore if role is None: raise MissingRole(item) return True return check(predicate) -def has_any_role(*items): +def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: r"""A :func:`.check` that is added that checks if the member invoking the command has **any** of the roles specified. This means that if they have one out of the three roles specified, then this check will return `True`. @@ -1651,14 +1851,15 @@ def has_any_role(*items): if ctx.guild is None: raise NoPrivateMessage() - getter = functools.partial(discord.utils.get, ctx.author.roles) + # ctx.guild is None doesn't narrow ctx.author to Member + getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): return True - raise MissingAnyRole(items) + raise MissingAnyRole(list(items)) return check(predicate) -def bot_has_role(item): +def bot_has_role(item: int) -> Callable[[T], T]: """Similar to :func:`.has_role` except checks if the bot itself has the role. @@ -1686,7 +1887,7 @@ def bot_has_role(item): return True return check(predicate) -def bot_has_any_role(*items): +def bot_has_any_role(*items: int) -> Callable[[T], T]: """Similar to :func:`.has_any_role` except checks if the bot itself has any of the roles listed. @@ -1707,10 +1908,10 @@ def bot_has_any_role(*items): getter = functools.partial(discord.utils.get, me.roles) if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): return True - raise BotMissingAnyRole(items) + raise BotMissingAnyRole(list(items)) return check(predicate) -def has_permissions(**perms): +def has_permissions(**perms: bool) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. @@ -1744,9 +1945,9 @@ def has_permissions(**perms): if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx): + def predicate(ctx: Context) -> bool: ch = ctx.channel - permissions = ch.permissions_for(ctx.author) + permissions = ch.permissions_for(ctx.author) # type: ignore missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] @@ -1757,7 +1958,7 @@ def has_permissions(**perms): return check(predicate) -def bot_has_permissions(**perms): +def bot_has_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. @@ -1769,10 +1970,10 @@ def bot_has_permissions(**perms): if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx): + def predicate(ctx: Context) -> bool: guild = ctx.guild me = guild.me if guild is not None else ctx.bot.user - permissions = ctx.channel.permissions_for(me) + permissions = ctx.channel.permissions_for(me) # type: ignore missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] @@ -1783,7 +1984,7 @@ def bot_has_permissions(**perms): return check(predicate) -def has_guild_permissions(**perms): +def has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions`, but operates on guild wide permissions instead of the current channel permissions. @@ -1797,11 +1998,11 @@ def has_guild_permissions(**perms): if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx): + def predicate(ctx: Context) -> bool: if not ctx.guild: raise NoPrivateMessage - permissions = ctx.author.guild_permissions + permissions = ctx.author.guild_permissions # type: ignore missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: @@ -1811,7 +2012,7 @@ def has_guild_permissions(**perms): return check(predicate) -def bot_has_guild_permissions(**perms): +def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_guild_permissions`, but checks the bot members guild permissions. @@ -1822,11 +2023,11 @@ def bot_has_guild_permissions(**perms): if invalid: raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") - def predicate(ctx): + def predicate(ctx: Context) -> bool: if not ctx.guild: raise NoPrivateMessage - permissions = ctx.me.guild_permissions + permissions = ctx.me.guild_permissions # type: ignore missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: @@ -1836,7 +2037,7 @@ def bot_has_guild_permissions(**perms): return check(predicate) -def dm_only(): +def dm_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when using the command. @@ -1847,14 +2048,14 @@ def dm_only(): .. versionadded:: 1.1 """ - def predicate(ctx): + def predicate(ctx: Context) -> bool: if ctx.guild is not None: raise PrivateMessageOnly() return True return check(predicate) -def guild_only(): +def guild_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when using the command. @@ -1863,14 +2064,14 @@ def guild_only(): that is inherited from :exc:`.CheckFailure`. """ - def predicate(ctx): + def predicate(ctx: Context) -> bool: if ctx.guild is None: raise NoPrivateMessage() return True return check(predicate) -def is_owner(): +def is_owner() -> Callable[[T], T]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -1880,14 +2081,14 @@ def is_owner(): from :exc:`.CheckFailure`. """ - async def predicate(ctx): + async def predicate(ctx: Context) -> bool: if not await ctx.bot.is_owner(ctx.author): raise NotOwner('You do not own this bot.') return True return check(predicate) -def is_nsfw(): +def is_nsfw() -> Callable[[T], T]: """A :func:`.check` that checks if the channel is a NSFW channel. This check raises a special exception, :exc:`.NSFWChannelRequired` @@ -1898,14 +2099,14 @@ def is_nsfw(): Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. DM channels will also now pass this check. """ - def pred(ctx): + def pred(ctx: Context) -> bool: ch = ctx.channel if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True - raise NSFWChannelRequired(ch) + raise NSFWChannelRequired(ch) # type: ignore return check(pred) -def cooldown(rate, per, type=BucketType.default): +def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` A cooldown allows a command to only be used a specific amount @@ -1932,15 +2133,15 @@ def cooldown(rate, per, type=BucketType.default): Callables are now supported for custom bucket types. """ - def decorator(func): + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func._buckets = CooldownMapping(Cooldown(rate, per), type) else: func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) return func - return decorator + return decorator # type: ignore -def dynamic_cooldown(cooldown, type=BucketType.default): +def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` This differs from :func:`.cooldown` in that it takes a function that @@ -1972,15 +2173,15 @@ def dynamic_cooldown(cooldown, type=BucketType.default): if not callable(cooldown): raise TypeError("A callable must be provided") - def decorator(func): + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func._buckets = DynamicCooldownMapping(cooldown, type) else: func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) return func - return decorator + return decorator # type: ignore -def max_concurrency(number, per=BucketType.default, *, wait=False): +def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. This enables you to only allow a certain number of command invocations at the same time, @@ -2004,16 +2205,16 @@ def max_concurrency(number, per=BucketType.default, *, wait=False): then the command waits until it can be executed. """ - def decorator(func): + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: value = MaxConcurrency(number, per=per, wait=wait) if isinstance(func, Command): func._max_concurrency = value else: func.__commands_max_concurrency__ = value return func - return decorator + return decorator # type: ignore -def before_invoke(coro): +def before_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a pre-invoke hook. This allows you to refer to one before invoke hook for several commands that @@ -2051,15 +2252,15 @@ def before_invoke(coro): bot.add_cog(What()) """ - def decorator(func): + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.before_invoke(coro) else: func.__before_invoke__ = coro return func - return decorator + return decorator # type: ignore -def after_invoke(coro): +def after_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a post-invoke hook. This allows you to refer to one after invoke hook for several commands that @@ -2067,10 +2268,10 @@ def after_invoke(coro): .. versionadded:: 1.4 """ - def decorator(func): + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.after_invoke(coro) else: func.__after_invoke__ = coro return func - return decorator + return decorator # type: ignore diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 6de81bb5f..6a70726d3 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -27,11 +27,17 @@ import copy import functools import inspect import re + +from typing import Optional, TYPE_CHECKING + import discord.utils from .core import Group, Command from .errors import CommandError +if TYPE_CHECKING: + from .context import Context + __all__ = ( 'Paginator', 'HelpCommand', @@ -320,7 +326,7 @@ class HelpCommand: self.command_attrs = attrs = options.pop('command_attrs', {}) attrs.setdefault('name', 'help') attrs.setdefault('help', 'Shows this message') - self.context = None + self.context: Optional[Context] = None self._command_impl = _HelpCommandImpl(self, **self.command_attrs) def copy(self):