|
|
@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
import collections |
|
|
|
import collections.abc |
|
|
|
import inspect |
|
|
|
import importlib.util |
|
|
|
import sys |
|
|
|
import traceback |
|
|
|
import types |
|
|
|
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union |
|
|
|
|
|
|
|
import discord |
|
|
|
|
|
|
@ -39,6 +44,15 @@ from . import errors |
|
|
|
from .help import HelpCommand, DefaultHelpCommand |
|
|
|
from .cog import Cog |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
import importlib.machinery |
|
|
|
|
|
|
|
from discord.message import Message |
|
|
|
from ._types import ( |
|
|
|
Check, |
|
|
|
CoroFunc, |
|
|
|
) |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'when_mentioned', |
|
|
|
'when_mentioned_or', |
|
|
@ -46,14 +60,21 @@ __all__ = ( |
|
|
|
'AutoShardedBot', |
|
|
|
) |
|
|
|
|
|
|
|
def when_mentioned(bot, msg): |
|
|
|
MISSING: Any = discord.utils.MISSING |
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
CFT = TypeVar('CFT', bound='CoroFunc') |
|
|
|
CXT = TypeVar('CXT', bound='Context') |
|
|
|
|
|
|
|
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: |
|
|
|
"""A callable that implements a command prefix equivalent to being mentioned. |
|
|
|
|
|
|
|
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. |
|
|
|
""" |
|
|
|
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] |
|
|
|
# bot.user will never be None when this is called |
|
|
|
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore |
|
|
|
|
|
|
|
def when_mentioned_or(*prefixes): |
|
|
|
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: |
|
|
|
"""A callable that implements when mentioned or other prefixes provided. |
|
|
|
|
|
|
|
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. |
|
|
@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes): |
|
|
|
|
|
|
|
return inner |
|
|
|
|
|
|
|
def _is_submodule(parent, child): |
|
|
|
def _is_submodule(parent: str, child: str) -> bool: |
|
|
|
return parent == child or child.startswith(parent + ".") |
|
|
|
|
|
|
|
class _DefaultRepr: |
|
|
@ -102,10 +123,10 @@ class BotBase(GroupMixin): |
|
|
|
def __init__(self, command_prefix, help_command=_default, description=None, **options): |
|
|
|
super().__init__(**options) |
|
|
|
self.command_prefix = command_prefix |
|
|
|
self.extra_events = {} |
|
|
|
self.__cogs = {} |
|
|
|
self.__extensions = {} |
|
|
|
self._checks = [] |
|
|
|
self.extra_events: Dict[str, List[CoroFunc]] = {} |
|
|
|
self.__cogs: Dict[str, Cog] = {} |
|
|
|
self.__extensions: Dict[str, types.ModuleType] = {} |
|
|
|
self._checks: List[Check] = [] |
|
|
|
self._check_once = [] |
|
|
|
self._before_invoke = None |
|
|
|
self._after_invoke = None |
|
|
@ -128,13 +149,14 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
# internal helpers |
|
|
|
|
|
|
|
def dispatch(self, event_name, *args, **kwargs): |
|
|
|
super().dispatch(event_name, *args, **kwargs) |
|
|
|
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: |
|
|
|
# super() will resolve to Client |
|
|
|
super().dispatch(event_name, *args, **kwargs) # type: ignore |
|
|
|
ev = 'on_' + event_name |
|
|
|
for event in self.extra_events.get(ev, []): |
|
|
|
self._schedule_event(event, ev, *args, **kwargs) |
|
|
|
self._schedule_event(event, ev, *args, **kwargs) # type: ignore |
|
|
|
|
|
|
|
async def close(self): |
|
|
|
async def close(self) -> None: |
|
|
|
for extension in tuple(self.__extensions): |
|
|
|
try: |
|
|
|
self.unload_extension(extension) |
|
|
@ -147,9 +169,9 @@ class BotBase(GroupMixin): |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
|
|
|
|
await super().close() |
|
|
|
await super().close() # type: ignore |
|
|
|
|
|
|
|
async def on_command_error(self, context, exception): |
|
|
|
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
The default command error handler provided by the bot. |
|
|
@ -175,7 +197,7 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
# global check registration |
|
|
|
|
|
|
|
def check(self, func): |
|
|
|
def check(self, func: T) -> T: |
|
|
|
r"""A decorator that adds a global check to the bot. |
|
|
|
|
|
|
|
A global check is similar to a :func:`.check` that is applied |
|
|
@ -200,10 +222,11 @@ class BotBase(GroupMixin): |
|
|
|
return ctx.command.qualified_name in allowed_commands |
|
|
|
|
|
|
|
""" |
|
|
|
self.add_check(func) |
|
|
|
# T was used instead of Check to ensure the type matches on return |
|
|
|
self.add_check(func) # type: ignore |
|
|
|
return func |
|
|
|
|
|
|
|
def add_check(self, func, *, call_once=False): |
|
|
|
def add_check(self, func: Check, *, call_once: bool = False) -> None: |
|
|
|
"""Adds a global check to the bot. |
|
|
|
|
|
|
|
This is the non-decorator interface to :meth:`.check` |
|
|
@ -223,7 +246,7 @@ class BotBase(GroupMixin): |
|
|
|
else: |
|
|
|
self._checks.append(func) |
|
|
|
|
|
|
|
def remove_check(self, func, *, call_once=False): |
|
|
|
def remove_check(self, func: Check, *, call_once: bool = False) -> None: |
|
|
|
"""Removes a global check from the bot. |
|
|
|
|
|
|
|
This function is idempotent and will not raise an exception |
|
|
@ -244,7 +267,7 @@ class BotBase(GroupMixin): |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
|
|
|
|
def check_once(self, func): |
|
|
|
def check_once(self, func: CFT) -> CFT: |
|
|
|
r"""A decorator that adds a "call once" global check to the bot. |
|
|
|
|
|
|
|
Unlike regular global checks, this one is called only once |
|
|
@ -282,15 +305,16 @@ class BotBase(GroupMixin): |
|
|
|
self.add_check(func, call_once=True) |
|
|
|
return func |
|
|
|
|
|
|
|
async def can_run(self, ctx, *, call_once=False): |
|
|
|
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: |
|
|
|
data = self._check_once if call_once else self._checks |
|
|
|
|
|
|
|
if len(data) == 0: |
|
|
|
return True |
|
|
|
|
|
|
|
return await discord.utils.async_all(f(ctx) for f in data) |
|
|
|
# type-checker doesn't distinguish between functions and methods |
|
|
|
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore |
|
|
|
|
|
|
|
async def is_owner(self, user): |
|
|
|
async def is_owner(self, user: discord.User) -> bool: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of |
|
|
@ -319,7 +343,8 @@ class BotBase(GroupMixin): |
|
|
|
elif self.owner_ids: |
|
|
|
return user.id in self.owner_ids |
|
|
|
else: |
|
|
|
app = await self.application_info() |
|
|
|
|
|
|
|
app = await self.application_info() # type: ignore |
|
|
|
if app.team: |
|
|
|
self.owner_ids = ids = {m.id for m in app.team.members} |
|
|
|
return user.id in ids |
|
|
@ -327,7 +352,7 @@ class BotBase(GroupMixin): |
|
|
|
self.owner_id = owner_id = app.owner.id |
|
|
|
return user.id == owner_id |
|
|
|
|
|
|
|
def before_invoke(self, coro): |
|
|
|
def before_invoke(self, coro: CFT) -> CFT: |
|
|
|
"""A decorator that registers a coroutine as a pre-invoke hook. |
|
|
|
|
|
|
|
A pre-invoke hook is called directly before the command is |
|
|
@ -359,7 +384,7 @@ class BotBase(GroupMixin): |
|
|
|
self._before_invoke = coro |
|
|
|
return coro |
|
|
|
|
|
|
|
def after_invoke(self, coro): |
|
|
|
def after_invoke(self, coro: CFT) -> CFT: |
|
|
|
r"""A decorator that registers a coroutine as a post-invoke hook. |
|
|
|
|
|
|
|
A post-invoke hook is called directly after the command is |
|
|
@ -394,14 +419,14 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
# listener registration |
|
|
|
|
|
|
|
def add_listener(self, func, name=None): |
|
|
|
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None: |
|
|
|
"""The non decorator alternative to :meth:`.listen`. |
|
|
|
|
|
|
|
Parameters |
|
|
|
----------- |
|
|
|
func: :ref:`coroutine <coroutine>` |
|
|
|
The function to call. |
|
|
|
name: Optional[:class:`str`] |
|
|
|
name: :class:`str` |
|
|
|
The name of the event to listen for. Defaults to ``func.__name__``. |
|
|
|
|
|
|
|
Example |
|
|
@ -416,7 +441,7 @@ class BotBase(GroupMixin): |
|
|
|
bot.add_listener(my_message, 'on_message') |
|
|
|
|
|
|
|
""" |
|
|
|
name = func.__name__ if name is None else name |
|
|
|
name = func.__name__ if name is MISSING else name |
|
|
|
|
|
|
|
if not asyncio.iscoroutinefunction(func): |
|
|
|
raise TypeError('Listeners must be coroutines') |
|
|
@ -426,7 +451,7 @@ class BotBase(GroupMixin): |
|
|
|
else: |
|
|
|
self.extra_events[name] = [func] |
|
|
|
|
|
|
|
def remove_listener(self, func, name=None): |
|
|
|
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None: |
|
|
|
"""Removes a listener from the pool of listeners. |
|
|
|
|
|
|
|
Parameters |
|
|
@ -438,7 +463,7 @@ class BotBase(GroupMixin): |
|
|
|
``func.__name__``. |
|
|
|
""" |
|
|
|
|
|
|
|
name = func.__name__ if name is None else name |
|
|
|
name = func.__name__ if name is MISSING else name |
|
|
|
|
|
|
|
if name in self.extra_events: |
|
|
|
try: |
|
|
@ -446,7 +471,7 @@ class BotBase(GroupMixin): |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
|
|
|
|
def listen(self, name=None): |
|
|
|
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]: |
|
|
|
"""A decorator that registers another function as an external |
|
|
|
event listener. Basically this allows you to listen to multiple |
|
|
|
events from different places e.g. such as :func:`.on_ready` |
|
|
@ -476,7 +501,7 @@ class BotBase(GroupMixin): |
|
|
|
The function being listened to is not a coroutine. |
|
|
|
""" |
|
|
|
|
|
|
|
def decorator(func): |
|
|
|
def decorator(func: CFT) -> CFT: |
|
|
|
self.add_listener(func, name) |
|
|
|
return func |
|
|
|
|
|
|
@ -528,7 +553,7 @@ class BotBase(GroupMixin): |
|
|
|
cog = cog._inject(self) |
|
|
|
self.__cogs[cog_name] = cog |
|
|
|
|
|
|
|
def get_cog(self, name): |
|
|
|
def get_cog(self, name: str) -> Optional[Cog]: |
|
|
|
"""Gets the cog instance requested. |
|
|
|
|
|
|
|
If the cog is not found, ``None`` is returned instead. |
|
|
@ -547,7 +572,7 @@ class BotBase(GroupMixin): |
|
|
|
""" |
|
|
|
return self.__cogs.get(name) |
|
|
|
|
|
|
|
def remove_cog(self, name): |
|
|
|
def remove_cog(self, name: str) -> Optional[Cog]: |
|
|
|
"""Removes a cog from the bot and returns it. |
|
|
|
|
|
|
|
All registered commands and event listeners that the |
|
|
@ -578,13 +603,13 @@ class BotBase(GroupMixin): |
|
|
|
return cog |
|
|
|
|
|
|
|
@property |
|
|
|
def cogs(self): |
|
|
|
def cogs(self) -> Mapping[str, Cog]: |
|
|
|
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" |
|
|
|
return types.MappingProxyType(self.__cogs) |
|
|
|
|
|
|
|
# extensions |
|
|
|
|
|
|
|
def _remove_module_references(self, name): |
|
|
|
def _remove_module_references(self, name: str) -> None: |
|
|
|
# find all references to the module |
|
|
|
# remove the cogs registered from the module |
|
|
|
for cogname, cog in self.__cogs.copy().items(): |
|
|
@ -608,7 +633,7 @@ class BotBase(GroupMixin): |
|
|
|
for index in reversed(remove): |
|
|
|
del event_list[index] |
|
|
|
|
|
|
|
def _call_module_finalizers(self, lib, key): |
|
|
|
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: |
|
|
|
try: |
|
|
|
func = getattr(lib, 'teardown') |
|
|
|
except AttributeError: |
|
|
@ -626,12 +651,12 @@ class BotBase(GroupMixin): |
|
|
|
if _is_submodule(name, module): |
|
|
|
del sys.modules[module] |
|
|
|
|
|
|
|
def _load_from_module_spec(self, spec, key): |
|
|
|
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: |
|
|
|
# precondition: key not in self.__extensions |
|
|
|
lib = importlib.util.module_from_spec(spec) |
|
|
|
sys.modules[key] = lib |
|
|
|
try: |
|
|
|
spec.loader.exec_module(lib) |
|
|
|
spec.loader.exec_module(lib) # type: ignore |
|
|
|
except Exception as e: |
|
|
|
del sys.modules[key] |
|
|
|
raise errors.ExtensionFailed(key, e) from e |
|
|
@ -652,13 +677,13 @@ class BotBase(GroupMixin): |
|
|
|
else: |
|
|
|
self.__extensions[key] = lib |
|
|
|
|
|
|
|
def _resolve_name(self, name, package): |
|
|
|
def _resolve_name(self, name: str, package: Optional[str]) -> str: |
|
|
|
try: |
|
|
|
return importlib.util.resolve_name(name, package) |
|
|
|
except ImportError: |
|
|
|
raise errors.ExtensionNotFound(name) |
|
|
|
|
|
|
|
def load_extension(self, name, *, package=None): |
|
|
|
def load_extension(self, name: str, *, package: Optional[str] = None) -> None: |
|
|
|
"""Loads an extension. |
|
|
|
|
|
|
|
An extension is a python module that contains commands, cogs, or |
|
|
@ -705,7 +730,7 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
self._load_from_module_spec(spec, name) |
|
|
|
|
|
|
|
def unload_extension(self, name, *, package=None): |
|
|
|
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: |
|
|
|
"""Unloads an extension. |
|
|
|
|
|
|
|
When the extension is unloaded, all commands, listeners, and cogs are |
|
|
@ -746,7 +771,7 @@ class BotBase(GroupMixin): |
|
|
|
self._remove_module_references(lib.__name__) |
|
|
|
self._call_module_finalizers(lib, name) |
|
|
|
|
|
|
|
def reload_extension(self, name, *, package=None): |
|
|
|
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: |
|
|
|
"""Atomically reloads an extension. |
|
|
|
|
|
|
|
This replaces the extension with the same extension, only refreshed. This is |
|
|
@ -802,7 +827,7 @@ class BotBase(GroupMixin): |
|
|
|
# if the load failed, the remnants should have been |
|
|
|
# cleaned from the load_extension function call |
|
|
|
# so let's load it from our old compiled library. |
|
|
|
lib.setup(self) |
|
|
|
lib.setup(self) # type: ignore |
|
|
|
self.__extensions[name] = lib |
|
|
|
|
|
|
|
# revert sys.modules back to normal and raise back to caller |
|
|
@ -810,18 +835,18 @@ class BotBase(GroupMixin): |
|
|
|
raise |
|
|
|
|
|
|
|
@property |
|
|
|
def extensions(self): |
|
|
|
def extensions(self) -> Mapping[str, types.ModuleType]: |
|
|
|
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" |
|
|
|
return types.MappingProxyType(self.__extensions) |
|
|
|
|
|
|
|
# help command stuff |
|
|
|
|
|
|
|
@property |
|
|
|
def help_command(self): |
|
|
|
def help_command(self) -> Optional[HelpCommand]: |
|
|
|
return self._help_command |
|
|
|
|
|
|
|
@help_command.setter |
|
|
|
def help_command(self, value): |
|
|
|
def help_command(self, value: Optional[HelpCommand]) -> None: |
|
|
|
if value is not None: |
|
|
|
if not isinstance(value, HelpCommand): |
|
|
|
raise TypeError('help_command must be a subclass of HelpCommand') |
|
|
@ -837,7 +862,7 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
# command processing |
|
|
|
|
|
|
|
async def get_prefix(self, message): |
|
|
|
async def get_prefix(self, message: Message) -> Union[List[str], str]: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Retrieves the prefix the bot is listening to |
|
|
@ -875,7 +900,7 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
return ret |
|
|
|
|
|
|
|
async def get_context(self, message, *, cls=Context): |
|
|
|
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT: |
|
|
|
r"""|coro| |
|
|
|
|
|
|
|
Returns the invocation context from the message. |
|
|
@ -908,7 +933,7 @@ class BotBase(GroupMixin): |
|
|
|
view = StringView(message.content) |
|
|
|
ctx = cls(prefix=None, view=view, bot=self, message=message) |
|
|
|
|
|
|
|
if message.author.id == self.user.id: |
|
|
|
if message.author.id == self.user.id: # type: ignore |
|
|
|
return ctx |
|
|
|
|
|
|
|
prefix = await self.get_prefix(message) |
|
|
@ -945,11 +970,12 @@ class BotBase(GroupMixin): |
|
|
|
|
|
|
|
invoker = view.get_word() |
|
|
|
ctx.invoked_with = invoker |
|
|
|
ctx.prefix = invoked_prefix |
|
|
|
# type-checker fails to narrow invoked_prefix type. |
|
|
|
ctx.prefix = invoked_prefix # type: ignore |
|
|
|
ctx.command = self.all_commands.get(invoker) |
|
|
|
return ctx |
|
|
|
|
|
|
|
async def invoke(self, ctx): |
|
|
|
async def invoke(self, ctx: Context) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Invokes the command given under the invocation context and |
|
|
@ -975,7 +1001,7 @@ class BotBase(GroupMixin): |
|
|
|
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') |
|
|
|
self.dispatch('command_error', ctx, exc) |
|
|
|
|
|
|
|
async def process_commands(self, message): |
|
|
|
async def process_commands(self, message: Message) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
This function processes the commands that have been registered |
|
|
|