From c8994b203eac90fc1cbf85df144414acfd872c5f Mon Sep 17 00:00:00 2001 From: andrei Date: Wed, 10 Jan 2018 18:42:35 -0800 Subject: [PATCH] Play around with some better abstractions --- disco/bot/bot.py | 4 ++-- disco/bot/command.py | 25 +++++++++++-------------- disco/bot/plugin.py | 5 +++-- disco/client.py | 11 ++++++++++- disco/gateway/client.py | 41 +++++++++++++++++++++++++++++------------ disco/gateway/events.py | 3 ++- 6 files changed, 57 insertions(+), 32 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index f90b882..0d9740a 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -258,7 +258,7 @@ class Bot(LoggingClass): Computes a single regex which matches all possible command combinations. """ commands = list(self.commands) - re_str = '|'.join(command.regex(grouped=False) for command in commands) + re_str = '|'.join(command.regex(self.group_abbrev, grouped=False) for command in commands) if re_str: self.command_matches_re = re.compile(re_str, re.I) else: @@ -326,7 +326,7 @@ class Bot(LoggingClass): options = [] for command in self.commands: - match = command.compiled_regex.match(content) + match = command.compiled_regex(self.group_abbrev).match(content) if match: options.append((command, match)) return sorted(options, key=lambda obj: obj[0].group is None) diff --git a/disco/bot/command.py b/disco/bot/command.py index 75904b9..6915635 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -120,8 +120,6 @@ class Command(object): Attributes ---------- - plugin : :class:`disco.bot.plugin.Plugin` - The plugin this command is a member of. func : function The function which is called when this command is triggered. trigger : str @@ -135,8 +133,7 @@ class Command(object): is_regex : Optional[bool] Whether the triggers for this command should be treated as raw regex. """ - def __init__(self, plugin, func, trigger, *args, **kwargs): - self.plugin = plugin + def __init__(self, func, trigger, *args, **kwargs): self.func = func self.triggers = [trigger] @@ -216,6 +213,8 @@ class Command(object): if parser: self.parser = PluginArgumentParser(prog=self.name, add_help=False) + self._cached_regex = None + @staticmethod def mention_type(getters, reg=None, user=False, allow_plain=False): def _f(ctx, raw): @@ -244,14 +243,12 @@ class Command(object): raise TypeError('Cannot resolve mention: {}'.format(raw)) return _f - @simple_cached_property - def compiled_regex(self): - """ - A compiled version of this command's regex. - """ - return re.compile(self.regex(), re.I) + def compiled_regex(self, group_abbrev): + if not self._cached_regex: + self._cached_regex = re.compile(self.regex(group_abbrev), re.I) + return self._cached_regex - def regex(self, grouped=True): + def regex(self, group_abbrev, grouped=True): """ The regex string that defines/triggers this command. """ @@ -260,8 +257,8 @@ class Command(object): else: group = '' if self.group: - if self.group in self.plugin.bot.group_abbrev: - rest = self.plugin.bot.group_abbrev[self.group] + if self.group in group_abbrev: + rest = group_abbrev[self.group] group = '{}(?:{}) '.format(rest, ''.join(c + u'?' for c in self.group[len(rest):])) else: group = self.group + ' ' @@ -303,4 +300,4 @@ class Command(object): kwargs = {} kwargs.update(self.context) kwargs.update(parsed_kwargs) - return self.plugin.dispatch('command', self, event, **kwargs) + return (event, kwargs) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 8422cb7..930b867 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -296,7 +296,8 @@ class Plugin(LoggingClass, PluginDeco): if not event.command.oob: self.greenlets.add(gevent.getcurrent()) try: - return event.command.execute(event) + command_event, kwargs = event.command.execute(event) + return self.plugin.dispatch('command', event.command, command_event, **kwargs) except CommandError as e: event.msg.reply(e.msg) return False @@ -377,7 +378,7 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - self.commands.append(Command(self, func, *args, **kwargs)) + self.commands.append(Command(func, *args, **kwargs)) def register_schedule(self, func, interval, repeat=True, init=True): """ diff --git a/disco/client.py b/disco/client.py index 14ff5ea..a14aa20 100644 --- a/disco/client.py +++ b/disco/client.py @@ -92,7 +92,16 @@ class Client(LoggingClass): self.packets = Emitter() self.api = APIClient(self.config.token, self) - self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder) + self.gw = GatewayClient( + token=self.config.token, + shard_id=self.config.shard_id, + shard_count=self.config.shard_count, + max_reconnects=self.config.max_reconnects, + encoder=self.config.encoder, + events=self.events, + packets=self.packets, + client=self, + ) self.state = State(self, StateConfig(self.config.get('state', {}))) if self.config.manhole_enable: diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 92ad7ff..246947c 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -3,6 +3,7 @@ import zlib import six import ssl +from holster.emitter import Emitter from websocket import ABNF from disco.gateway.packets import OPCode, RECV, SEND @@ -19,15 +20,29 @@ ZLIB_SUFFIX = b'\x00\x00\xff\xff' class GatewayClient(LoggingClass): GATEWAY_VERSION = 6 - def __init__(self, client, max_reconnects=5, encoder='json', zlib_stream_enabled=True, ipc=None): + def __init__( + self, + token, + shard_id=0, + shard_count=1, + max_reconnects=5, + encoder='json', + zlib_stream_enabled=True, + ipc=None, + events=None, + packets=None, + client=None): super(GatewayClient, self).__init__() - self.client = client + self.token = token + self.shard_id = shard_id + self.shard_count = shard_count self.max_reconnects = max_reconnects self.encoder = ENCODERS[encoder] self.zlib_stream_enabled = zlib_stream_enabled - self.events = client.events - self.packets = client.packets + self.client = client + self.events = events or Emitter() + self.packets = packets or Emitter() # IPC for shards if ipc: @@ -88,7 +103,7 @@ class GatewayClient(LoggingClass): def handle_dispatch(self, packet): obj = GatewayEvent.from_dispatch(self.client, packet) self.log.debug('GatewayClient.handle_dispatch %s', obj.__class__.__name__) - self.client.events.emit(obj.__class__.__name__, obj) + self.events.emit(obj.__class__.__name__, obj) if self.replaying: self.replayed_events += 1 @@ -121,10 +136,12 @@ class GatewayClient(LoggingClass): def connect_and_run(self, gateway_url=None): if not gateway_url: - if not self._cached_gateway_url: + if not self._cached_gateway_url and self.client: self._cached_gateway_url = self.client.api.gateway_get()['url'] gateway_url = self._cached_gateway_url + else: + self._cached_gateway_url = gateway_url gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE) @@ -191,19 +208,19 @@ class GatewayClient(LoggingClass): self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.replaying = True self.send(OPCode.RESUME, { - 'token': self.client.config.token, + 'token': self.token, 'session_id': self.session_id, 'seq': self.seq, }) else: self.log.info('WS Opened: sending identify payload') self.send(OPCode.IDENTIFY, { - 'token': self.client.config.token, + 'token': self.token, 'compress': True, 'large_threshold': 250, 'shard': [ - int(self.client.config.shard_id), - int(self.client.config.shard_count), + int(self.shard_id), + int(self.shard_count), ], 'properties': { '$os': 'linux', @@ -247,6 +264,6 @@ class GatewayClient(LoggingClass): # Reconnect self.connect_and_run() - def run(self): - gevent.spawn(self.connect_and_run) + def run(self, gateway_url=None): + gevent.spawn(self.connect_and_run, gateway_url=gateway_url) self.ws_event.wait() diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 9b76525..425e2a6 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -1,6 +1,7 @@ from __future__ import print_function import six +import copy from disco.types.user import User, Presence from disco.types.channel import Channel, PermissionOverwrite @@ -48,7 +49,7 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): """ Create this GatewayEvent class from data and the client. """ - cls.raw_data = obj + cls.raw_data = copy.deepcopy(obj) # If this event is wrapping a model, pull its fields if hasattr(cls, '_wraps_model'):