From db54d2a3921b3a7eae0c254e632121c19d62005a Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 21 Apr 2017 10:54:42 -0700 Subject: [PATCH] Fix issues with calculating abbreviations (closes #15) --- disco/bot/bot.py | 39 ++++++++++++++++++++++++++------------- tests/test_bot.py | 26 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 13 deletions(-) create mode 100644 tests/test_bot.py diff --git a/disco/bot/bot.py b/disco/bot/bot.py index e7ccad4..52b4aab 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -65,6 +65,7 @@ class BotConfig(Config): The directory plugin configuration is located within. """ levels = {} + plugins = [] plugin_config = {} commands_enabled = True @@ -196,28 +197,40 @@ class Bot(LoggingClass): Called when a plugin is loaded/unloaded to recompute internal state. """ if self.config.commands_group_abbrev: - self.compute_group_abbrev() + groups = set(command.group for command in self.commands if command.group) + self.group_abbrev = self.compute_group_abbrev(groups) self.compute_command_matches_re() - def compute_group_abbrev(self): + def compute_group_abbrev(self, groups): """ Computes all possible abbreviations for a command grouping. """ - self.group_abbrev = {} - groups = set(command.group for command in self.commands if command.group) - + # For the first pass, we just want to compute each groups possible + # abbreviations that don't conflict with eachother. + possible = {} for group in groups: - grp = group - while grp: - # If the group already exists, means someone else thought they - # could use it so we need yank it from them (and not use it) - if grp in list(six.itervalues(self.group_abbrev)): - self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp} + for index in range(len(group)): + current = group[:index] + if current in possible: + possible[current] = None else: - self.group_abbrev[group] = grp + possible[current] = group + + # Now, we want to compute the actual shortest abbreivation out of the + # possible ones + result = {} + for abbrev, group in six.iteritems(possible): + if not group: + continue - grp = grp[:-1] + if group in result: + if len(abbrev) < len(result[group]): + result[group] = abbrev + else: + result[group] = abbrev + + return result def compute_command_matches_re(self): """ diff --git a/tests/test_bot.py b/tests/test_bot.py new file mode 100644 index 0000000..5650dc2 --- /dev/null +++ b/tests/test_bot.py @@ -0,0 +1,26 @@ +from unittest import TestCase + +from disco.client import ClientConfig, Client +from disco.bot.bot import Bot + + +class TestBot(TestCase): + def setUp(self): + self.client = Client(ClientConfig( + {'config': 'TEST_TOKEN'} + )) + self.bot = Bot(self.client) + + def test_command_abbreviation(self): + groups = ['config', 'copy', 'copez', 'copypasta'] + result = self.bot.compute_group_abbrev(groups) + self.assertDictEqual(result, { + 'config': 'con', + 'copypasta': 'copy', + 'copez': 'cope', + }) + + def test_command_abbreivation_conflicting(self): + groups = ['cat', 'cap', 'caz', 'cas'] + result = self.bot.compute_group_abbrev(groups) + self.assertDictEqual(result, {})