From 839682291bfdc9bb758a89b240fc7e8f7ed4dc3f Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 19 Apr 2019 09:50:35 -0700 Subject: [PATCH] fix: handle loading of user permission levels --- disco/bot/bot.py | 11 ++++++++--- disco/types/base.py | 16 +++------------- disco/util/enum.py | 20 ++++++++++++++++++++ 3 files changed, 31 insertions(+), 16 deletions(-) create mode 100644 disco/util/enum.py diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 5db6e29..d82bc1d 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -16,6 +16,7 @@ from disco.util.config import Config from disco.util.logging import LoggingClass from disco.util.serializer import Serializer from disco.util.threadlocal import ThreadLocal +from disco.util.enum import get_enum_value_by_name class BotConfig(Config): @@ -185,9 +186,13 @@ class Bot(LoggingClass): for plugin_mod in self.config.plugins: self.add_plugin_module(plugin_mod) - # Convert level mapping - for k, v in list(six.iteritems(self.config.levels)): - self.config.levels[int(k) if k.isdigit() else k] = CommandLevels.get(v) + # Convert our configured mapping of entities to levels into something + # we can actually use. This ensures IDs are converted properly, and maps + # any level names (e.g. `role_id: admin`) map to their numerical values. + for entity_id, level in six.iteritems(self.config.levels): + entity_id = int(entity_id) if str(entity_id).isdigit() else entity_id + level = int(level) if str(level).isdigit() else get_enum_value_by_name(CommandLevels, level) + self.config.levels[entity_id] = level @classmethod def from_cli(cls, *plugins): diff --git a/disco/types/base.py b/disco/types/base.py index 94913a8..c876f9a 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -7,6 +7,7 @@ from datetime import datetime as real_datetime from disco.util.chains import Chainable from disco.util.hashmap import HashMap +from disco.util.enum import get_enum_members DATETIME_FORMATS = [ '%Y-%m-%dT%H:%M:%S.%f', @@ -191,23 +192,12 @@ def snowflake(data): return int(data) if data else None -def _enum_attrs(enum): - for k, v in six.iteritems(enum.__dict__): - if not isinstance(k, six.string_types): - continue - - if k.startswith('_') or not k.isupper(): - continue - - yield k, v - - def enum(typ): def _f(data): if data is None: return None - for k, v in _enum_attrs(typ): + for k, v in get_enum_members(typ): if isinstance(data, six.string_types) and k == data.upper(): return v elif k == data or v == data: @@ -396,7 +386,7 @@ class Model(six.with_metaclass(ModelMeta, Chainable)): if ignore and name in ignore: continue - if getattr(self, name) == UNSET: + if getattr(self, name) is UNSET: continue obj[name] = field.serialize(getattr(self, name), field) return obj diff --git a/disco/util/enum.py b/disco/util/enum.py new file mode 100644 index 0000000..d633d01 --- /dev/null +++ b/disco/util/enum.py @@ -0,0 +1,20 @@ +import six + + +def get_enum_members(enum): + for k, v in six.iteritems(enum.__dict__): + if not isinstance(k, six.string_types): + continue + + if k.startswith('_') or not k.isupper(): + continue + + yield k, v + + +def get_enum_value_by_name(enum, name): + name = name.lower() + + for k, v in get_enum_members(enum): + if k.lower() == name: + return v