Browse Source

[commands] Don't raise ExtensionNotFound for ImportErrors in modules

Now loading an extension that _contains_ a failed import will fail
with ExtensionFailed, rather than ExtensionNotFound.
pull/2255/head
Benjamin Mintz 6 years ago
committed by Rapptz
parent
commit
0a21591d0c
  1. 24
      discord/ext/commands/bot.py
  2. 16
      discord/ext/commands/errors.py

24
discord/ext/commands/bot.py

@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
import asyncio import asyncio
import collections import collections
import inspect import inspect
import importlib import importlib.util
import sys import sys
import traceback import traceback
import re import re
@ -588,12 +588,17 @@ class BotBase(GroupMixin):
if _is_submodule(name, module): if _is_submodule(name, module):
del sys.modules[module] del sys.modules[module]
def _load_from_module_spec(self, lib, key): def _load_from_module_spec(self, spec, key):
# precondition: key not in self.__extensions # precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(lib)
except Exception as e:
raise errors.ExtensionFailed(key, e) from e
try: try:
setup = getattr(lib, 'setup') setup = getattr(lib, 'setup')
except AttributeError: except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key) raise errors.NoEntryPointError(key)
try: try:
@ -603,7 +608,7 @@ class BotBase(GroupMixin):
self._call_module_finalizers(lib, key) self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e raise errors.ExtensionFailed(key, e) from e
else: else:
self.__extensions[key] = lib sys.modules[key] = self.__extensions[key] = lib
def load_extension(self, name): def load_extension(self, name):
"""Loads an extension. """Loads an extension.
@ -637,12 +642,11 @@ class BotBase(GroupMixin):
if name in self.__extensions: if name in self.__extensions:
raise errors.ExtensionAlreadyLoaded(name) raise errors.ExtensionAlreadyLoaded(name)
try: spec = importlib.util.find_spec(name)
lib = importlib.import_module(name) if spec is None:
except ImportError as e: raise errors.ExtensionNotFound(name)
raise errors.ExtensionNotFound(name, e) from e
else: self._load_from_module_spec(spec, name)
self._load_from_module_spec(lib, name)
def unload_extension(self, name): def unload_extension(self, name):
"""Unloads an extension. """Unloads an extension.

16
discord/ext/commands/errors.py

@ -503,7 +503,7 @@ class NoEntryPointError(ExtensionError):
super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name) super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name)
class ExtensionFailed(ExtensionError): class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the ``setup`` entry point. """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
@ -521,19 +521,21 @@ class ExtensionFailed(ExtensionError):
super().__init__(fmt.format(name, original), name=name) super().__init__(fmt.format(name, original), name=name)
class ExtensionNotFound(ExtensionError): class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension failed to be imported. """An exception raised when an extension is not found.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
.. versionchanged:: 1.3.0
Made the ``original`` attribute always None.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
The extension that had the error. The extension that had the error.
original: :exc:`ImportError` original: :class:`NoneType`
The original exception that was raised. You can also get this via Always ``None`` for backwards compatibility.
the ``__cause__`` attribute.
""" """
def __init__(self, name, original): def __init__(self, name, original=None):
self.original = original self.original = None
fmt = 'Extension {0!r} could not be loaded.' fmt = 'Extension {0!r} could not be loaded.'
super().__init__(fmt.format(name), name=name) super().__init__(fmt.format(name), name=name)

Loading…
Cancel
Save