diff --git a/.gitignore b/.gitignore index c0f7693..b6d3025 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ storage.json .cache/ .benchmarks/ __pycache__ +.venv # Documentation stuff docs/api/ diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 93962d1..0679381 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -10,7 +10,7 @@ from holster.threadlocal import ThreadLocal from gevent.pywsgi import WSGIServer from disco.types.guild import GuildMember -from disco.bot.plugin import Plugin +from disco.bot.plugin import find_loadable_plugins from disco.bot.command import CommandEvent, CommandLevels from disco.bot.storage import Storage from disco.util.config import Config @@ -499,12 +499,10 @@ class Bot(LoggingClass): mod = importlib.import_module(path) loaded = False - for entry in map(lambda i: getattr(mod, i), dir(mod)): - if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin: - if getattr(entry, '_shallow', False) and Plugin in entry.__bases__: - continue - loaded = True - self.add_plugin(entry, config) + plugins = find_loadable_plugins(mod) + for plugin in plugins: + loaded = True + self.add_plugin(plugin, config) if not loaded: raise Exception('Could not find any plugins to load within module {}'.format(path)) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a5e589e..a19f773 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -3,6 +3,7 @@ import types import gevent import inspect import weakref +import warnings import functools from gevent.event import AsyncResult @@ -12,6 +13,56 @@ from disco.util.logging import LoggingClass from disco.bot.command import Command, CommandError +# Contains a list of classes which will be excluded when auto discovering plugins +# to load. This allows anyone to create subclasses of Plugin that act as a base +# plugin class within their project/bot. +_plugin_base_classes = set() + + +def register_plugin_base_class(cls): + """ + This function registers the given class under an internal registry of plugin + base classes. This will cause the class passed to behave exactly like the + builtin `Plugin` class. + + This is particularly useful if you wish to subclass `Plugin` to create a new + base class that other plugins in your project inherit from, but do not want + the automatic plugin loading to consider the class for loading. + """ + if not inspect.isclass(cls): + raise TypeError('cls must be a class') + + _plugin_base_classes.add(cls) + return cls + + +def find_loadable_plugins(mod): + """ + Generator which produces a list of loadable plugins given a Python module. This + function will exclude any plugins which are registered as a plugin base class + via the `register_plugin_base_class` function. + """ + module_attributes = (getattr(mod, attr) for attr in dir(mod)) + for modattr in module_attributes: + if not inspect.isclass(modattr): + continue + + if not issubclass(modattr, Plugin): + continue + + if modattr in _plugin_base_classes: + continue + + if getattr(modattr, '_shallow', False) and Plugin in modattr.__bases__: + warnings.warn( + 'Setting _shallow to avoid plugin loading has been deprecated, see `register_plugin_base_class`', + DeprecationWarning, + ) + continue + + yield modattr + + class BasePluginDeco(object): Prio = Priority @@ -151,6 +202,7 @@ class PluginDeco(BasePluginDeco): parser = BasePluginDeco +@register_plugin_base_class class Plugin(LoggingClass, PluginDeco): """ A plugin is a set of listeners/commands which can be loaded/unloaded by a bot. diff --git a/tests/bot/plugin.py b/tests/bot/plugin.py new file mode 100644 index 0000000..d1ff34a --- /dev/null +++ b/tests/bot/plugin.py @@ -0,0 +1,47 @@ +import pytest +from disco.bot.plugin import Plugin, register_plugin_base_class, find_loadable_plugins + + +def _test_module(**kwargs): + class MyTestModule(object): + pass + + m = MyTestModule() + + for k, v in kwargs.items(): + setattr(m, k, v) + + return m + + +def test_shallow_attribute_deprecated(): + class MyPluginBaseClass(Plugin): + _shallow = True + + class RegularPlugin(Plugin): + pass + + with pytest.warns(DeprecationWarning): + plugins = list(find_loadable_plugins(_test_module( + MyPluginBaseClass=MyPluginBaseClass, + RegularPlugin=RegularPlugin, + ))) + + assert plugins == [RegularPlugin] + + +def test_register_plugin_base_class(): + class MyPluginBaseClass(Plugin): + pass + + class RegularPlugin(MyPluginBaseClass): + pass + + register_plugin_base_class(MyPluginBaseClass) + + plugins = list(find_loadable_plugins(_test_module( + MyPluginBaseClass=MyPluginBaseClass, + RegularPlugin=RegularPlugin, + ))) + + assert plugins == [RegularPlugin]