Browse Source

[commands][types] Type hint commands-ext

pull/7441/head
Josh 4 years ago
committed by GitHub
parent
commit
f3cb197429
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 20
      discord/ext/commands/_types.py
  2. 132
      discord/ext/commands/bot.py
  3. 82
      discord/ext/commands/cog.py
  4. 149
      discord/ext/commands/context.py
  5. 557
      discord/ext/commands/core.py
  6. 8
      discord/ext/commands/help.py

20
discord/ext/commands/_types.py

@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
# This is merely a tag type to avoid circular import issues. # This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution. # Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand: class _BaseCommand:

132
discord/ext/commands/bot.py

@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import collections import collections
import collections.abc
import inspect import inspect
import importlib.util import importlib.util
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
import discord import discord
@ -39,6 +44,15 @@ from . import errors
from .help import HelpCommand, DefaultHelpCommand from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog from .cog import Cog
if TYPE_CHECKING:
import importlib.machinery
from discord.message import Message
from ._types import (
Check,
CoroFunc,
)
__all__ = ( __all__ = (
'when_mentioned', 'when_mentioned',
'when_mentioned_or', 'when_mentioned_or',
@ -46,14 +60,21 @@ __all__ = (
'AutoShardedBot', 'AutoShardedBot',
) )
def when_mentioned(bot, msg): MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned. """A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
""" """
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes): def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided. """A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
return inner return inner
def _is_submodule(parent, child): def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".") return parent == child or child.startswith(parent + ".")
class _DefaultRepr: class _DefaultRepr:
@ -102,10 +123,10 @@ class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options): def __init__(self, command_prefix, help_command=_default, description=None, **options):
super().__init__(**options) super().__init__(**options)
self.command_prefix = command_prefix self.command_prefix = command_prefix
self.extra_events = {} self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs = {} self.__cogs: Dict[str, Cog] = {}
self.__extensions = {} self.__extensions: Dict[str, types.ModuleType] = {}
self._checks = [] self._checks: List[Check] = []
self._check_once = [] self._check_once = []
self._before_invoke = None self._before_invoke = None
self._after_invoke = None self._after_invoke = None
@ -128,13 +149,14 @@ class BotBase(GroupMixin):
# internal helpers # internal helpers
def dispatch(self, event_name, *args, **kwargs): def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
super().dispatch(event_name, *args, **kwargs) # super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name ev = 'on_' + event_name
for event in self.extra_events.get(ev, []): for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) self._schedule_event(event, ev, *args, **kwargs) # type: ignore
async def close(self): async def close(self) -> None:
for extension in tuple(self.__extensions): for extension in tuple(self.__extensions):
try: try:
self.unload_extension(extension) self.unload_extension(extension)
@ -147,9 +169,9 @@ class BotBase(GroupMixin):
except Exception: except Exception:
pass pass
await super().close() await super().close() # type: ignore
async def on_command_error(self, context, exception): async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
"""|coro| """|coro|
The default command error handler provided by the bot. The default command error handler provided by the bot.
@ -175,7 +197,7 @@ class BotBase(GroupMixin):
# global check registration # global check registration
def check(self, func): def check(self, func: T) -> T:
r"""A decorator that adds a global check to the bot. r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied A global check is similar to a :func:`.check` that is applied
@ -200,10 +222,11 @@ class BotBase(GroupMixin):
return ctx.command.qualified_name in allowed_commands return ctx.command.qualified_name in allowed_commands
""" """
self.add_check(func) # T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore
return func return func
def add_check(self, func, *, call_once=False): def add_check(self, func: Check, *, call_once: bool = False) -> None:
"""Adds a global check to the bot. """Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check` This is the non-decorator interface to :meth:`.check`
@ -223,7 +246,7 @@ class BotBase(GroupMixin):
else: else:
self._checks.append(func) self._checks.append(func)
def remove_check(self, func, *, call_once=False): def remove_check(self, func: Check, *, call_once: bool = False) -> None:
"""Removes a global check from the bot. """Removes a global check from the bot.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -244,7 +267,7 @@ class BotBase(GroupMixin):
except ValueError: except ValueError:
pass pass
def check_once(self, func): def check_once(self, func: CFT) -> CFT:
r"""A decorator that adds a "call once" global check to the bot. r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once Unlike regular global checks, this one is called only once
@ -282,15 +305,16 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True) self.add_check(func, call_once=True)
return func return func
async def can_run(self, ctx, *, call_once=False): async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks data = self._check_once if call_once else self._checks
if len(data) == 0: if len(data) == 0:
return True return True
return await discord.utils.async_all(f(ctx) for f in data) # type-checker doesn't distinguish between functions and methods
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user): async def is_owner(self, user: discord.User) -> bool:
"""|coro| """|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
@ -319,7 +343,8 @@ class BotBase(GroupMixin):
elif self.owner_ids: elif self.owner_ids:
return user.id in self.owner_ids return user.id in self.owner_ids
else: else:
app = await self.application_info()
app = await self.application_info() # type: ignore
if app.team: if app.team:
self.owner_ids = ids = {m.id for m in app.team.members} self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids return user.id in ids
@ -327,7 +352,7 @@ class BotBase(GroupMixin):
self.owner_id = owner_id = app.owner.id self.owner_id = owner_id = app.owner.id
return user.id == owner_id return user.id == owner_id
def before_invoke(self, coro): def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is A pre-invoke hook is called directly before the command is
@ -359,7 +384,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro self._before_invoke = coro
return coro return coro
def after_invoke(self, coro): def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook. r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is A post-invoke hook is called directly after the command is
@ -394,14 +419,14 @@ class BotBase(GroupMixin):
# listener registration # listener registration
def add_listener(self, func, name=None): def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""The non decorator alternative to :meth:`.listen`. """The non decorator alternative to :meth:`.listen`.
Parameters Parameters
----------- -----------
func: :ref:`coroutine <coroutine>` func: :ref:`coroutine <coroutine>`
The function to call. The function to call.
name: Optional[:class:`str`] name: :class:`str`
The name of the event to listen for. Defaults to ``func.__name__``. The name of the event to listen for. Defaults to ``func.__name__``.
Example Example
@ -416,7 +441,7 @@ class BotBase(GroupMixin):
bot.add_listener(my_message, 'on_message') bot.add_listener(my_message, 'on_message')
""" """
name = func.__name__ if name is None else name name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines') raise TypeError('Listeners must be coroutines')
@ -426,7 +451,7 @@ class BotBase(GroupMixin):
else: else:
self.extra_events[name] = [func] self.extra_events[name] = [func]
def remove_listener(self, func, name=None): def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""Removes a listener from the pool of listeners. """Removes a listener from the pool of listeners.
Parameters Parameters
@ -438,7 +463,7 @@ class BotBase(GroupMixin):
``func.__name__``. ``func.__name__``.
""" """
name = func.__name__ if name is None else name name = func.__name__ if name is MISSING else name
if name in self.extra_events: if name in self.extra_events:
try: try:
@ -446,7 +471,7 @@ class BotBase(GroupMixin):
except ValueError: except ValueError:
pass pass
def listen(self, name=None): def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
"""A decorator that registers another function as an external """A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready` events from different places e.g. such as :func:`.on_ready`
@ -476,7 +501,7 @@ class BotBase(GroupMixin):
The function being listened to is not a coroutine. The function being listened to is not a coroutine.
""" """
def decorator(func): def decorator(func: CFT) -> CFT:
self.add_listener(func, name) self.add_listener(func, name)
return func return func
@ -528,7 +553,7 @@ class BotBase(GroupMixin):
cog = cog._inject(self) cog = cog._inject(self)
self.__cogs[cog_name] = cog self.__cogs[cog_name] = cog
def get_cog(self, name): def get_cog(self, name: str) -> Optional[Cog]:
"""Gets the cog instance requested. """Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead. If the cog is not found, ``None`` is returned instead.
@ -547,7 +572,7 @@ class BotBase(GroupMixin):
""" """
return self.__cogs.get(name) return self.__cogs.get(name)
def remove_cog(self, name): def remove_cog(self, name: str) -> Optional[Cog]:
"""Removes a cog from the bot and returns it. """Removes a cog from the bot and returns it.
All registered commands and event listeners that the All registered commands and event listeners that the
@ -578,13 +603,13 @@ class BotBase(GroupMixin):
return cog return cog
@property @property
def cogs(self): def cogs(self) -> Mapping[str, Cog]:
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" """Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs) return types.MappingProxyType(self.__cogs)
# extensions # extensions
def _remove_module_references(self, name): def _remove_module_references(self, name: str) -> None:
# find all references to the module # find all references to the module
# remove the cogs registered from the module # remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items(): for cogname, cog in self.__cogs.copy().items():
@ -608,7 +633,7 @@ class BotBase(GroupMixin):
for index in reversed(remove): for index in reversed(remove):
del event_list[index] del event_list[index]
def _call_module_finalizers(self, lib, key): def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try: try:
func = getattr(lib, 'teardown') func = getattr(lib, 'teardown')
except AttributeError: except AttributeError:
@ -626,12 +651,12 @@ class BotBase(GroupMixin):
if _is_submodule(name, module): if _is_submodule(name, module):
del sys.modules[module] del sys.modules[module]
def _load_from_module_spec(self, spec, key): def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions # precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec) lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib sys.modules[key] = lib
try: try:
spec.loader.exec_module(lib) spec.loader.exec_module(lib) # type: ignore
except Exception as e: except Exception as e:
del sys.modules[key] del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e raise errors.ExtensionFailed(key, e) from e
@ -652,13 +677,13 @@ class BotBase(GroupMixin):
else: else:
self.__extensions[key] = lib self.__extensions[key] = lib
def _resolve_name(self, name, package): def _resolve_name(self, name: str, package: Optional[str]) -> str:
try: try:
return importlib.util.resolve_name(name, package) return importlib.util.resolve_name(name, package)
except ImportError: except ImportError:
raise errors.ExtensionNotFound(name) raise errors.ExtensionNotFound(name)
def load_extension(self, name, *, package=None): def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension. """Loads an extension.
An extension is a python module that contains commands, cogs, or An extension is a python module that contains commands, cogs, or
@ -705,7 +730,7 @@ class BotBase(GroupMixin):
self._load_from_module_spec(spec, name) self._load_from_module_spec(spec, name)
def unload_extension(self, name, *, package=None): def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension. """Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are When the extension is unloaded, all commands, listeners, and cogs are
@ -746,7 +771,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__) self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name) self._call_module_finalizers(lib, name)
def reload_extension(self, name, *, package=None): def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension. """Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is This replaces the extension with the same extension, only refreshed. This is
@ -802,7 +827,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been # if the load failed, the remnants should have been
# cleaned from the load_extension function call # cleaned from the load_extension function call
# so let's load it from our old compiled library. # so let's load it from our old compiled library.
lib.setup(self) lib.setup(self) # type: ignore
self.__extensions[name] = lib self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller # revert sys.modules back to normal and raise back to caller
@ -810,18 +835,18 @@ class BotBase(GroupMixin):
raise raise
@property @property
def extensions(self): def extensions(self) -> Mapping[str, types.ModuleType]:
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" """Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions) return types.MappingProxyType(self.__extensions)
# help command stuff # help command stuff
@property @property
def help_command(self): def help_command(self) -> Optional[HelpCommand]:
return self._help_command return self._help_command
@help_command.setter @help_command.setter
def help_command(self, value): def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None: if value is not None:
if not isinstance(value, HelpCommand): if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand') raise TypeError('help_command must be a subclass of HelpCommand')
@ -837,7 +862,7 @@ class BotBase(GroupMixin):
# command processing # command processing
async def get_prefix(self, message): async def get_prefix(self, message: Message) -> Union[List[str], str]:
"""|coro| """|coro|
Retrieves the prefix the bot is listening to Retrieves the prefix the bot is listening to
@ -875,7 +900,7 @@ class BotBase(GroupMixin):
return ret return ret
async def get_context(self, message, *, cls=Context): async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
r"""|coro| r"""|coro|
Returns the invocation context from the message. Returns the invocation context from the message.
@ -908,7 +933,7 @@ class BotBase(GroupMixin):
view = StringView(message.content) view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message) ctx = cls(prefix=None, view=view, bot=self, message=message)
if message.author.id == self.user.id: if message.author.id == self.user.id: # type: ignore
return ctx return ctx
prefix = await self.get_prefix(message) prefix = await self.get_prefix(message)
@ -945,11 +970,12 @@ class BotBase(GroupMixin):
invoker = view.get_word() invoker = view.get_word()
ctx.invoked_with = invoker ctx.invoked_with = invoker
ctx.prefix = invoked_prefix # type-checker fails to narrow invoked_prefix type.
ctx.prefix = invoked_prefix # type: ignore
ctx.command = self.all_commands.get(invoker) ctx.command = self.all_commands.get(invoker)
return ctx return ctx
async def invoke(self, ctx): async def invoke(self, ctx: Context) -> None:
"""|coro| """|coro|
Invokes the command given under the invocation context and Invokes the command given under the invocation context and
@ -975,7 +1001,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc) self.dispatch('command_error', ctx, exc)
async def process_commands(self, message): async def process_commands(self, message: Message) -> None:
"""|coro| """|coro|
This function processes the commands that have been registered This function processes the commands that have been registered

82
discord/ext/commands/cog.py

@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import inspect import inspect
import discord.utils
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from ._types import _BaseCommand from ._types import _BaseCommand
if TYPE_CHECKING:
from .bot import BotBase
from .context import Context
from .core import Command
__all__ = ( __all__ = (
'CogMeta', 'CogMeta',
'Cog', 'Cog',
) )
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type): class CogMeta(type):
"""A metaclass for defining a cog. """A metaclass for defining a cog.
@ -89,8 +104,12 @@ class CogMeta(type):
async def bar(self, ctx): async def bar(self, ctx):
pass # hidden -> False pass # hidden -> False
""" """
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args, **kwargs): def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name) attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
@ -143,14 +162,14 @@ class CogMeta(type):
new_cls.__cog_listeners__ = listeners_as_list new_cls.__cog_listeners__ = listeners_as_list
return new_cls return new_cls
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args) super().__init__(*args)
@classmethod @classmethod
def qualified_name(cls): def qualified_name(cls) -> str:
return cls.__cog_name__ return cls.__cog_name__
def _cog_special_method(func): def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None func.__cog_special_method__ = None
return func return func
@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta` When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here. are equally valid here.
""" """
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
def __new__(cls, *args, **kwargs): def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
# For issue 426, we need to store a copy of the command objects # For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them. # since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process. # To do this, we need to interfere with the Cog creation process.
@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta):
cmd_attrs = cls.__cog_settings__ cmd_attrs = cls.__cog_settings__
# Either update the command with the cog provided defaults or copy it. # Either update the command with the cog provided defaults or copy it.
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = { lookup = {
cmd.qualified_name: cmd cmd.qualified_name: cmd
@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta):
parent = command.parent parent = command.parent
if parent is not None: if parent is not None:
# Get the latest parent reference # Get the latest parent reference
parent = lookup[parent.qualified_name] parent = lookup[parent.qualified_name] # type: ignore
# Update our parent's reference to our self # Update our parent's reference to our self
parent.remove_command(command.name) parent.remove_command(command.name) # type: ignore
parent.add_command(command) parent.add_command(command) # type: ignore
return self return self
def get_commands(self): def get_commands(self) -> List[Command]:
r""" r"""
Returns Returns
-------- --------
@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta):
return [c for c in self.__cog_commands__ if c.parent is None] return [c for c in self.__cog_commands__ if c.parent is None]
@property @property
def qualified_name(self): def qualified_name(self) -> str:
""":class:`str`: Returns the cog's specified name, not the class name.""" """:class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__ return self.__cog_name__
@property @property
def description(self): def description(self) -> str:
""":class:`str`: Returns the cog's description, typically the cleaned docstring.""" """:class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__ return self.__cog_description__
@description.setter @description.setter
def description(self, description): def description(self, description: str) -> None:
self.__cog_description__ = description self.__cog_description__ = description
def walk_commands(self): def walk_commands(self) -> Generator[Command, None, None]:
"""An iterator that recursively walks through this cog's commands and subcommands. """An iterator that recursively walks through this cog's commands and subcommands.
Yields Yields
@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta):
if isinstance(command, GroupMixin): if isinstance(command, GroupMixin):
yield from command.walk_commands() yield from command.walk_commands()
def get_listeners(self): def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog. """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
Returns Returns
@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta):
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod @classmethod
def _get_overridden_method(cls, method): def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method.""" """Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method) return getattr(method.__func__, '__cog_special_method__', method)
@classmethod @classmethod
def listener(cls, name=None): def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
"""A decorator that marks a function as a listener. """A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`. This is the cog equivalent of :meth:`.Bot.listen`.
@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta):
the name. the name.
""" """
if name is not None and not isinstance(name, str): if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
def decorator(func): def decorator(func: FuncT) -> FuncT:
actual = func actual = func
if isinstance(actual, staticmethod): if isinstance(actual, staticmethod):
actual = actual.__func__ actual = actual.__func__
@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta):
return func return func
return decorator return decorator
def has_error_handler(self): def has_error_handler(self) -> bool:
""":class:`bool`: Checks whether the cog has an error handler. """:class:`bool`: Checks whether the cog has an error handler.
.. versionadded:: 1.7 .. versionadded:: 1.7
@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method @_cog_special_method
def cog_unload(self): def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed. """A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular This function **cannot** be a coroutine. It must be a regular
@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
def bot_check_once(self, ctx): def bot_check_once(self, ctx: Context) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once` """A special method that registers as a :meth:`.Bot.check_once`
check. check.
@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def bot_check(self, ctx): def bot_check(self, ctx: Context) -> bool:
"""A special method that registers as a :meth:`.Bot.check` """A special method that registers as a :meth:`.Bot.check`
check. check.
@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def cog_check(self, ctx): def cog_check(self, ctx: Context) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check` """A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog. for every command and subcommand in this cog.
@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
async def cog_command_error(self, ctx, error): async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""A special method that is called whenever an error """A special method that is called whenever an error
is dispatched inside this cog. is dispatched inside this cog.
@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_before_invoke(self, ctx): async def cog_before_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local pre-invoke hook. """A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`. This is similar to :meth:`.Command.before_invoke`.
@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_after_invoke(self, ctx): async def cog_after_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local post-invoke hook. """A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`. This is similar to :meth:`.Command.after_invoke`.
@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta):
""" """
pass pass
def _inject(self, bot): def _inject(self: CogT, bot: BotBase) -> CogT:
cls = self.__class__ cls = self.__class__
# realistically, the only thing that can cause loading errors # realistically, the only thing that can cause loading errors
@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta):
return self return self
def _eject(self, bot): def _eject(self, bot: BotBase) -> None:
cls = self.__class__ cls = self.__class__
try: try:

149
discord/ext/commands/context.py

@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import inspect
import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
import discord.abc import discord.abc
import discord.utils import discord.utils
import re
from discord.message import Message
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from discord.abc import MessageableChannel
from discord.guild import Guild
from discord.member import Member
from discord.state import ConnectionState
from discord.user import ClientUser, User
from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot
from .cog import Cog
from .core import Command
from .help import HelpCommand
from .view import StringView
__all__ = ( __all__ = (
'Context', 'Context',
) )
class Context(discord.abc.Messageable): MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]):
r"""Represents the context in which a command is being invoked under. r"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about This class contains a lot of meta data to help you understand more about
@ -58,11 +94,11 @@ class Context(discord.abc.Messageable):
This is only of use for within converters. This is only of use for within converters.
.. versionadded:: 2.0 .. versionadded:: 2.0
prefix: :class:`str` prefix: Optional[:class:`str`]
The prefix that was used to invoke the command. The prefix that was used to invoke the command.
command: :class:`Command` command: Optional[:class:`Command`]
The command that is being invoked currently. The command that is being invoked currently.
invoked_with: :class:`str` invoked_with: Optional[:class:`str`]
The command name that triggered this invocation. Useful for finding out The command name that triggered this invocation. Useful for finding out
which alias called the command. which alias called the command.
invoked_parents: List[:class:`str`] invoked_parents: List[:class:`str`]
@ -73,7 +109,7 @@ class Context(discord.abc.Messageable):
.. versionadded:: 1.7 .. versionadded:: 1.7
invoked_subcommand: :class:`Command` invoked_subcommand: Optional[:class:`Command`]
The subcommand that was invoked. The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``. If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`] subcommand_passed: Optional[:class:`str`]
@ -86,23 +122,38 @@ class Context(discord.abc.Messageable):
or invoked. or invoked.
""" """
def __init__(self, **attrs): def __init__(self,
self.message = attrs.pop('message', None) *,
self.bot = attrs.pop('bot', None) message: Message,
self.args = attrs.pop('args', []) bot: BotT,
self.kwargs = attrs.pop('kwargs', {}) view: StringView,
self.prefix = attrs.pop('prefix') args: List[Any] = MISSING,
self.command = attrs.pop('command', None) kwargs: Dict[str, Any] = MISSING,
self.view = attrs.pop('view', None) prefix: Optional[str] = None,
self.invoked_with = attrs.pop('invoked_with', None) command: Optional[Command] = None,
self.invoked_parents = attrs.pop('invoked_parents', []) invoked_with: Optional[str] = None,
self.invoked_subcommand = attrs.pop('invoked_subcommand', None) invoked_parents: List[str] = MISSING,
self.subcommand_passed = attrs.pop('subcommand_passed', None) invoked_subcommand: Optional[Command] = None,
self.command_failed = attrs.pop('command_failed', False) subcommand_passed: Optional[str] = None,
self.current_parameter = attrs.pop('current_parameter', None) command_failed: bool = False,
self._state = self.message._state current_parameter: Optional[inspect.Parameter] = None,
):
async def invoke(self, command, /, *args, **kwargs): self.message: Message = message
self.bot: BotT = bot
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
r"""|coro| r"""|coro|
Calls a command with the arguments given. Calls a command with the arguments given.
@ -133,17 +184,9 @@ class Context(discord.abc.Messageable):
TypeError TypeError
The command argument to invoke is missing. The command argument to invoke is missing.
""" """
arguments = [] return await command(self, *args, **kwargs)
if command.cog is not None:
arguments.append(command.cog)
arguments.append(self)
arguments.extend(args)
ret = await command.callback(*arguments, **kwargs) async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
return ret
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
"""|coro| """|coro|
Calls the command again. Calls the command again.
@ -187,7 +230,7 @@ class Context(discord.abc.Messageable):
if restart: if restart:
to_call = cmd.root_parent or cmd to_call = cmd.root_parent or cmd
view.index = len(self.prefix) view.index = len(self.prefix or '')
view.previous = 0 view.previous = 0
self.invoked_parents = [] self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command self.invoked_with = view.get_word() # advance to get the root command
@ -206,20 +249,23 @@ class Context(discord.abc.Messageable):
self.subcommand_passed = subcommand_passed self.subcommand_passed = subcommand_passed
@property @property
def valid(self): def valid(self) -> bool:
""":class:`bool`: Checks if the invocation context is valid to be invoked with.""" """:class:`bool`: Checks if the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None return self.prefix is not None and self.command is not None
async def _get_channel(self): async def _get_channel(self) -> discord.abc.Messageable:
return self.channel return self.channel
@property @property
def clean_prefix(self): def clean_prefix(self) -> str:
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``. """:class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
user = self.guild.me if self.guild else self.bot.user if self.prefix is None:
return ''
user = self.me
# this breaks if the prefix mention is not the bot itself but I # this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go # consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the # for this common use case rather than waste performance for the
@ -228,7 +274,7 @@ class Context(discord.abc.Messageable):
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix) return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
@property @property
def cog(self): def cog(self) -> Optional[Cog]:
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist.""" """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
if self.command is None: if self.command is None:
@ -236,38 +282,39 @@ class Context(discord.abc.Messageable):
return self.command.cog return self.command.cog
@discord.utils.cached_property @discord.utils.cached_property
def guild(self): def guild(self) -> Optional[Guild]:
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available.""" """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
return self.message.guild return self.message.guild
@discord.utils.cached_property @discord.utils.cached_property
def channel(self): def channel(self) -> MessageableChannel:
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`. Shorthand for :attr:`.Message.channel`.
""" """
return self.message.channel return self.message.channel
@discord.utils.cached_property @discord.utils.cached_property
def author(self): def author(self) -> Union[User, Member]:
"""Union[:class:`~discord.User`, :class:`.Member`]: """Union[:class:`~discord.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
""" """
return self.message.author return self.message.author
@discord.utils.cached_property @discord.utils.cached_property
def me(self): def me(self) -> Union[Member, ClientUser]:
"""Union[:class:`.Member`, :class:`.ClientUser`]: """Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts. Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
""" """
return self.guild.me if self.guild is not None else self.bot.user # bot.user will never be None at this point.
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
@property @property
def voice_client(self): def voice_client(self) -> Optional[VoiceProtocol]:
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild g = self.guild
return g.voice_client if g else None return g.voice_client if g else None
async def send_help(self, *args): async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>) """send_help(entity=<bot>)
|coro| |coro|
@ -319,12 +366,12 @@ class Context(discord.abc.Messageable):
return None return None
entity = args[0] entity = args[0]
if entity is None:
return None
if isinstance(entity, str): if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity) entity = bot.get_cog(entity) or bot.get_command(entity)
if entity is None:
return None
try: try:
entity.qualified_name entity.qualified_name
except AttributeError: except AttributeError:
@ -348,6 +395,6 @@ class Context(discord.abc.Messageable):
except CommandError as e: except CommandError as e:
await cmd.on_help_command_error(self, e) await cmd.on_help_command_error(self, e)
@discord.utils.copy_doc(discord.Message.reply) @discord.utils.copy_doc(Message.reply)
async def reply(self, content=None, **kwargs): async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
return await self.message.reply(content, **kwargs) return await self.message.reply(content, **kwargs)

557
discord/ext/commands/core.py

File diff suppressed because it is too large

8
discord/ext/commands/help.py

@ -27,11 +27,17 @@ import copy
import functools import functools
import inspect import inspect
import re import re
from typing import Optional, TYPE_CHECKING
import discord.utils import discord.utils
from .core import Group, Command from .core import Group, Command
from .errors import CommandError from .errors import CommandError
if TYPE_CHECKING:
from .context import Context
__all__ = ( __all__ = (
'Paginator', 'Paginator',
'HelpCommand', 'HelpCommand',
@ -320,7 +326,7 @@ class HelpCommand:
self.command_attrs = attrs = options.pop('command_attrs', {}) self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help') attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message') attrs.setdefault('help', 'Shows this message')
self.context = None self.context: Optional[Context] = None
self._command_impl = _HelpCommandImpl(self, **self.command_attrs) self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self): def copy(self):

Loading…
Cancel
Save