Compare commits

...

1 Commits

Author SHA1 Message Date
andrei c8994b203e Play around with some better abstractions 7 years ago
  1. 4
      disco/bot/bot.py
  2. 25
      disco/bot/command.py
  3. 5
      disco/bot/plugin.py
  4. 11
      disco/client.py
  5. 41
      disco/gateway/client.py
  6. 3
      disco/gateway/events.py

4
disco/bot/bot.py

@ -258,7 +258,7 @@ class Bot(LoggingClass):
Computes a single regex which matches all possible command combinations. Computes a single regex which matches all possible command combinations.
""" """
commands = list(self.commands) commands = list(self.commands)
re_str = '|'.join(command.regex(grouped=False) for command in commands) re_str = '|'.join(command.regex(self.group_abbrev, grouped=False) for command in commands)
if re_str: if re_str:
self.command_matches_re = re.compile(re_str, re.I) self.command_matches_re = re.compile(re_str, re.I)
else: else:
@ -326,7 +326,7 @@ class Bot(LoggingClass):
options = [] options = []
for command in self.commands: for command in self.commands:
match = command.compiled_regex.match(content) match = command.compiled_regex(self.group_abbrev).match(content)
if match: if match:
options.append((command, match)) options.append((command, match))
return sorted(options, key=lambda obj: obj[0].group is None) return sorted(options, key=lambda obj: obj[0].group is None)

25
disco/bot/command.py

@ -120,8 +120,6 @@ class Command(object):
Attributes Attributes
---------- ----------
plugin : :class:`disco.bot.plugin.Plugin`
The plugin this command is a member of.
func : function func : function
The function which is called when this command is triggered. The function which is called when this command is triggered.
trigger : str trigger : str
@ -135,8 +133,7 @@ class Command(object):
is_regex : Optional[bool] is_regex : Optional[bool]
Whether the triggers for this command should be treated as raw regex. Whether the triggers for this command should be treated as raw regex.
""" """
def __init__(self, plugin, func, trigger, *args, **kwargs): def __init__(self, func, trigger, *args, **kwargs):
self.plugin = plugin
self.func = func self.func = func
self.triggers = [trigger] self.triggers = [trigger]
@ -216,6 +213,8 @@ class Command(object):
if parser: if parser:
self.parser = PluginArgumentParser(prog=self.name, add_help=False) self.parser = PluginArgumentParser(prog=self.name, add_help=False)
self._cached_regex = None
@staticmethod @staticmethod
def mention_type(getters, reg=None, user=False, allow_plain=False): def mention_type(getters, reg=None, user=False, allow_plain=False):
def _f(ctx, raw): def _f(ctx, raw):
@ -244,14 +243,12 @@ class Command(object):
raise TypeError('Cannot resolve mention: {}'.format(raw)) raise TypeError('Cannot resolve mention: {}'.format(raw))
return _f return _f
@simple_cached_property def compiled_regex(self, group_abbrev):
def compiled_regex(self): if not self._cached_regex:
""" self._cached_regex = re.compile(self.regex(group_abbrev), re.I)
A compiled version of this command's regex. return self._cached_regex
"""
return re.compile(self.regex(), re.I)
def regex(self, grouped=True): def regex(self, group_abbrev, grouped=True):
""" """
The regex string that defines/triggers this command. The regex string that defines/triggers this command.
""" """
@ -260,8 +257,8 @@ class Command(object):
else: else:
group = '' group = ''
if self.group: if self.group:
if self.group in self.plugin.bot.group_abbrev: if self.group in group_abbrev:
rest = self.plugin.bot.group_abbrev[self.group] rest = group_abbrev[self.group]
group = '{}(?:{}) '.format(rest, ''.join(c + u'?' for c in self.group[len(rest):])) group = '{}(?:{}) '.format(rest, ''.join(c + u'?' for c in self.group[len(rest):]))
else: else:
group = self.group + ' ' group = self.group + ' '
@ -303,4 +300,4 @@ class Command(object):
kwargs = {} kwargs = {}
kwargs.update(self.context) kwargs.update(self.context)
kwargs.update(parsed_kwargs) kwargs.update(parsed_kwargs)
return self.plugin.dispatch('command', self, event, **kwargs) return (event, kwargs)

5
disco/bot/plugin.py

@ -296,7 +296,8 @@ class Plugin(LoggingClass, PluginDeco):
if not event.command.oob: if not event.command.oob:
self.greenlets.add(gevent.getcurrent()) self.greenlets.add(gevent.getcurrent())
try: try:
return event.command.execute(event) command_event, kwargs = event.command.execute(event)
return self.plugin.dispatch('command', event.command, command_event, **kwargs)
except CommandError as e: except CommandError as e:
event.msg.reply(e.msg) event.msg.reply(e.msg)
return False return False
@ -377,7 +378,7 @@ class Plugin(LoggingClass, PluginDeco):
Keyword arguments to pass onto the :class:`disco.bot.command.Command` Keyword arguments to pass onto the :class:`disco.bot.command.Command`
object. object.
""" """
self.commands.append(Command(self, func, *args, **kwargs)) self.commands.append(Command(func, *args, **kwargs))
def register_schedule(self, func, interval, repeat=True, init=True): def register_schedule(self, func, interval, repeat=True, init=True):
""" """

11
disco/client.py

@ -92,7 +92,16 @@ class Client(LoggingClass):
self.packets = Emitter() self.packets = Emitter()
self.api = APIClient(self.config.token, self) self.api = APIClient(self.config.token, self)
self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder) self.gw = GatewayClient(
token=self.config.token,
shard_id=self.config.shard_id,
shard_count=self.config.shard_count,
max_reconnects=self.config.max_reconnects,
encoder=self.config.encoder,
events=self.events,
packets=self.packets,
client=self,
)
self.state = State(self, StateConfig(self.config.get('state', {}))) self.state = State(self, StateConfig(self.config.get('state', {})))
if self.config.manhole_enable: if self.config.manhole_enable:

41
disco/gateway/client.py

@ -3,6 +3,7 @@ import zlib
import six import six
import ssl import ssl
from holster.emitter import Emitter
from websocket import ABNF from websocket import ABNF
from disco.gateway.packets import OPCode, RECV, SEND from disco.gateway.packets import OPCode, RECV, SEND
@ -19,15 +20,29 @@ ZLIB_SUFFIX = b'\x00\x00\xff\xff'
class GatewayClient(LoggingClass): class GatewayClient(LoggingClass):
GATEWAY_VERSION = 6 GATEWAY_VERSION = 6
def __init__(self, client, max_reconnects=5, encoder='json', zlib_stream_enabled=True, ipc=None): def __init__(
self,
token,
shard_id=0,
shard_count=1,
max_reconnects=5,
encoder='json',
zlib_stream_enabled=True,
ipc=None,
events=None,
packets=None,
client=None):
super(GatewayClient, self).__init__() super(GatewayClient, self).__init__()
self.client = client self.token = token
self.shard_id = shard_id
self.shard_count = shard_count
self.max_reconnects = max_reconnects self.max_reconnects = max_reconnects
self.encoder = ENCODERS[encoder] self.encoder = ENCODERS[encoder]
self.zlib_stream_enabled = zlib_stream_enabled self.zlib_stream_enabled = zlib_stream_enabled
self.events = client.events self.client = client
self.packets = client.packets self.events = events or Emitter()
self.packets = packets or Emitter()
# IPC for shards # IPC for shards
if ipc: if ipc:
@ -88,7 +103,7 @@ class GatewayClient(LoggingClass):
def handle_dispatch(self, packet): def handle_dispatch(self, packet):
obj = GatewayEvent.from_dispatch(self.client, packet) obj = GatewayEvent.from_dispatch(self.client, packet)
self.log.debug('GatewayClient.handle_dispatch %s', obj.__class__.__name__) self.log.debug('GatewayClient.handle_dispatch %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj) self.events.emit(obj.__class__.__name__, obj)
if self.replaying: if self.replaying:
self.replayed_events += 1 self.replayed_events += 1
@ -121,10 +136,12 @@ class GatewayClient(LoggingClass):
def connect_and_run(self, gateway_url=None): def connect_and_run(self, gateway_url=None):
if not gateway_url: if not gateway_url:
if not self._cached_gateway_url: if not self._cached_gateway_url and self.client:
self._cached_gateway_url = self.client.api.gateway_get()['url'] self._cached_gateway_url = self.client.api.gateway_get()['url']
gateway_url = self._cached_gateway_url gateway_url = self._cached_gateway_url
else:
self._cached_gateway_url = gateway_url
gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE) gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE)
@ -191,19 +208,19 @@ class GatewayClient(LoggingClass):
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.replaying = True self.replaying = True
self.send(OPCode.RESUME, { self.send(OPCode.RESUME, {
'token': self.client.config.token, 'token': self.token,
'session_id': self.session_id, 'session_id': self.session_id,
'seq': self.seq, 'seq': self.seq,
}) })
else: else:
self.log.info('WS Opened: sending identify payload') self.log.info('WS Opened: sending identify payload')
self.send(OPCode.IDENTIFY, { self.send(OPCode.IDENTIFY, {
'token': self.client.config.token, 'token': self.token,
'compress': True, 'compress': True,
'large_threshold': 250, 'large_threshold': 250,
'shard': [ 'shard': [
int(self.client.config.shard_id), int(self.shard_id),
int(self.client.config.shard_count), int(self.shard_count),
], ],
'properties': { 'properties': {
'$os': 'linux', '$os': 'linux',
@ -247,6 +264,6 @@ class GatewayClient(LoggingClass):
# Reconnect # Reconnect
self.connect_and_run() self.connect_and_run()
def run(self): def run(self, gateway_url=None):
gevent.spawn(self.connect_and_run) gevent.spawn(self.connect_and_run, gateway_url=gateway_url)
self.ws_event.wait() self.ws_event.wait()

3
disco/gateway/events.py

@ -1,6 +1,7 @@
from __future__ import print_function from __future__ import print_function
import six import six
import copy
from disco.types.user import User, Presence from disco.types.user import User, Presence
from disco.types.channel import Channel, PermissionOverwrite from disco.types.channel import Channel, PermissionOverwrite
@ -48,7 +49,7 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)):
""" """
Create this GatewayEvent class from data and the client. Create this GatewayEvent class from data and the client.
""" """
cls.raw_data = obj cls.raw_data = copy.deepcopy(obj)
# If this event is wrapping a model, pull its fields # If this event is wrapping a model, pull its fields
if hasattr(cls, '_wraps_model'): if hasattr(cls, '_wraps_model'):

Loading…
Cancel
Save