Compare commits

...

1 Commits

Author SHA1 Message Date
andrei c8994b203e Play around with some better abstractions 7 years ago
  1. 4
      disco/bot/bot.py
  2. 25
      disco/bot/command.py
  3. 5
      disco/bot/plugin.py
  4. 11
      disco/client.py
  5. 41
      disco/gateway/client.py
  6. 3
      disco/gateway/events.py

4
disco/bot/bot.py

@ -258,7 +258,7 @@ class Bot(LoggingClass):
Computes a single regex which matches all possible command combinations.
"""
commands = list(self.commands)
re_str = '|'.join(command.regex(grouped=False) for command in commands)
re_str = '|'.join(command.regex(self.group_abbrev, grouped=False) for command in commands)
if re_str:
self.command_matches_re = re.compile(re_str, re.I)
else:
@ -326,7 +326,7 @@ class Bot(LoggingClass):
options = []
for command in self.commands:
match = command.compiled_regex.match(content)
match = command.compiled_regex(self.group_abbrev).match(content)
if match:
options.append((command, match))
return sorted(options, key=lambda obj: obj[0].group is None)

25
disco/bot/command.py

@ -120,8 +120,6 @@ class Command(object):
Attributes
----------
plugin : :class:`disco.bot.plugin.Plugin`
The plugin this command is a member of.
func : function
The function which is called when this command is triggered.
trigger : str
@ -135,8 +133,7 @@ class Command(object):
is_regex : Optional[bool]
Whether the triggers for this command should be treated as raw regex.
"""
def __init__(self, plugin, func, trigger, *args, **kwargs):
self.plugin = plugin
def __init__(self, func, trigger, *args, **kwargs):
self.func = func
self.triggers = [trigger]
@ -216,6 +213,8 @@ class Command(object):
if parser:
self.parser = PluginArgumentParser(prog=self.name, add_help=False)
self._cached_regex = None
@staticmethod
def mention_type(getters, reg=None, user=False, allow_plain=False):
def _f(ctx, raw):
@ -244,14 +243,12 @@ class Command(object):
raise TypeError('Cannot resolve mention: {}'.format(raw))
return _f
@simple_cached_property
def compiled_regex(self):
"""
A compiled version of this command's regex.
"""
return re.compile(self.regex(), re.I)
def compiled_regex(self, group_abbrev):
if not self._cached_regex:
self._cached_regex = re.compile(self.regex(group_abbrev), re.I)
return self._cached_regex
def regex(self, grouped=True):
def regex(self, group_abbrev, grouped=True):
"""
The regex string that defines/triggers this command.
"""
@ -260,8 +257,8 @@ class Command(object):
else:
group = ''
if self.group:
if self.group in self.plugin.bot.group_abbrev:
rest = self.plugin.bot.group_abbrev[self.group]
if self.group in group_abbrev:
rest = group_abbrev[self.group]
group = '{}(?:{}) '.format(rest, ''.join(c + u'?' for c in self.group[len(rest):]))
else:
group = self.group + ' '
@ -303,4 +300,4 @@ class Command(object):
kwargs = {}
kwargs.update(self.context)
kwargs.update(parsed_kwargs)
return self.plugin.dispatch('command', self, event, **kwargs)
return (event, kwargs)

5
disco/bot/plugin.py

@ -296,7 +296,8 @@ class Plugin(LoggingClass, PluginDeco):
if not event.command.oob:
self.greenlets.add(gevent.getcurrent())
try:
return event.command.execute(event)
command_event, kwargs = event.command.execute(event)
return self.plugin.dispatch('command', event.command, command_event, **kwargs)
except CommandError as e:
event.msg.reply(e.msg)
return False
@ -377,7 +378,7 @@ class Plugin(LoggingClass, PluginDeco):
Keyword arguments to pass onto the :class:`disco.bot.command.Command`
object.
"""
self.commands.append(Command(self, func, *args, **kwargs))
self.commands.append(Command(func, *args, **kwargs))
def register_schedule(self, func, interval, repeat=True, init=True):
"""

11
disco/client.py

@ -92,7 +92,16 @@ class Client(LoggingClass):
self.packets = Emitter()
self.api = APIClient(self.config.token, self)
self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder)
self.gw = GatewayClient(
token=self.config.token,
shard_id=self.config.shard_id,
shard_count=self.config.shard_count,
max_reconnects=self.config.max_reconnects,
encoder=self.config.encoder,
events=self.events,
packets=self.packets,
client=self,
)
self.state = State(self, StateConfig(self.config.get('state', {})))
if self.config.manhole_enable:

41
disco/gateway/client.py

@ -3,6 +3,7 @@ import zlib
import six
import ssl
from holster.emitter import Emitter
from websocket import ABNF
from disco.gateway.packets import OPCode, RECV, SEND
@ -19,15 +20,29 @@ ZLIB_SUFFIX = b'\x00\x00\xff\xff'
class GatewayClient(LoggingClass):
GATEWAY_VERSION = 6
def __init__(self, client, max_reconnects=5, encoder='json', zlib_stream_enabled=True, ipc=None):
def __init__(
self,
token,
shard_id=0,
shard_count=1,
max_reconnects=5,
encoder='json',
zlib_stream_enabled=True,
ipc=None,
events=None,
packets=None,
client=None):
super(GatewayClient, self).__init__()
self.client = client
self.token = token
self.shard_id = shard_id
self.shard_count = shard_count
self.max_reconnects = max_reconnects
self.encoder = ENCODERS[encoder]
self.zlib_stream_enabled = zlib_stream_enabled
self.events = client.events
self.packets = client.packets
self.client = client
self.events = events or Emitter()
self.packets = packets or Emitter()
# IPC for shards
if ipc:
@ -88,7 +103,7 @@ class GatewayClient(LoggingClass):
def handle_dispatch(self, packet):
obj = GatewayEvent.from_dispatch(self.client, packet)
self.log.debug('GatewayClient.handle_dispatch %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj)
self.events.emit(obj.__class__.__name__, obj)
if self.replaying:
self.replayed_events += 1
@ -121,10 +136,12 @@ class GatewayClient(LoggingClass):
def connect_and_run(self, gateway_url=None):
if not gateway_url:
if not self._cached_gateway_url:
if not self._cached_gateway_url and self.client:
self._cached_gateway_url = self.client.api.gateway_get()['url']
gateway_url = self._cached_gateway_url
else:
self._cached_gateway_url = gateway_url
gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE)
@ -191,19 +208,19 @@ class GatewayClient(LoggingClass):
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.replaying = True
self.send(OPCode.RESUME, {
'token': self.client.config.token,
'token': self.token,
'session_id': self.session_id,
'seq': self.seq,
})
else:
self.log.info('WS Opened: sending identify payload')
self.send(OPCode.IDENTIFY, {
'token': self.client.config.token,
'token': self.token,
'compress': True,
'large_threshold': 250,
'shard': [
int(self.client.config.shard_id),
int(self.client.config.shard_count),
int(self.shard_id),
int(self.shard_count),
],
'properties': {
'$os': 'linux',
@ -247,6 +264,6 @@ class GatewayClient(LoggingClass):
# Reconnect
self.connect_and_run()
def run(self):
gevent.spawn(self.connect_and_run)
def run(self, gateway_url=None):
gevent.spawn(self.connect_and_run, gateway_url=gateway_url)
self.ws_event.wait()

3
disco/gateway/events.py

@ -1,6 +1,7 @@
from __future__ import print_function
import six
import copy
from disco.types.user import User, Presence
from disco.types.channel import Channel, PermissionOverwrite
@ -48,7 +49,7 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)):
"""
Create this GatewayEvent class from data and the client.
"""
cls.raw_data = obj
cls.raw_data = copy.deepcopy(obj)
# If this event is wrapping a model, pull its fields
if hasattr(cls, '_wraps_model'):

Loading…
Cancel
Save