Browse Source

Refine autosharding/IPC

pull/9/head
Andrei 9 years ago
parent
commit
aee7fa13e1
  1. 4
      disco/api/client.py
  2. 4
      disco/cli.py
  3. 2
      disco/client.py
  4. 100
      disco/gateway/ipc/gipc.py
  5. 48
      disco/gateway/sharder.py
  6. 48
      disco/types/message.py

4
disco/api/client.py

@ -26,11 +26,11 @@ class APIClient(LoggingClass):
An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits
the models with the returned data. the models with the returned data.
""" """
def __init__(self, client): def __init__(self, token, client=None):
super(APIClient, self).__init__() super(APIClient, self).__init__()
self.client = client self.client = client
self.http = HTTPClient(self.client.config.token) self.http = HTTPClient(token)
def gateway_get(self): def gateway_get(self):
data = self.http(Routes.GATEWAY_GET).json() data = self.http(Routes.GATEWAY_GET).json()

4
disco/cli.py

@ -25,8 +25,6 @@ parser.add_argument('--encoder', help='encoder for gateway data', default=None)
parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False) parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False)
parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[]) parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[])
logging.basicConfig(level=logging.INFO)
def disco_main(run=False): def disco_main(run=False):
""" """
@ -62,6 +60,8 @@ def disco_main(run=False):
AutoSharder(config).run() AutoSharder(config).run()
return return
logging.basicConfig(level=logging.INFO)
client = Client(config) client = Client(config)
bot = None bot = None

2
disco/client.py

@ -82,7 +82,7 @@ class Client(object):
self.events = Emitter(gevent.spawn) self.events = Emitter(gevent.spawn)
self.packets = Emitter(gevent.spawn) self.packets = Emitter(gevent.spawn)
self.api = APIClient(self) self.api = APIClient(self.config.token, self)
self.gw = GatewayClient(self, self.config.encoder) self.gw = GatewayClient(self, self.config.encoder)
self.state = State(self, StateConfig(self.config.get('state', {}))) self.state = State(self, StateConfig(self.config.get('state', {})))

100
disco/gateway/ipc/gipc.py

@ -1,50 +1,92 @@
import random import random
import gipc
import gevent import gevent
import string import string
import weakref import weakref
import marshal
import types
from holster.enum import Enum
from disco.util.logging import LoggingClass
def get_random_str(size): def get_random_str(size):
return ''.join([random.choice(string.ascii_printable) for _ in range(size)]) return ''.join([random.choice(string.printable) for _ in range(size)])
IPCMessageType = Enum(
'CALL_FUNC',
'GET_ATTR',
'EXECUTE',
'RESPONSE',
)
class GIPCProxy(object): class GIPCProxy(LoggingClass):
def __init__(self, pipe): def __init__(self, obj, pipe):
super(GIPCProxy, self).__init__()
self.obj = obj
self.pipe = pipe self.pipe = pipe
self.results = weakref.WeakValueDictionary() self.results = weakref.WeakValueDictionary()
gevent.spawn(self.read_loop) gevent.spawn(self.read_loop)
def read_loop(self): def resolve(self, parts):
while True: base = self.obj
nonce, data = self.pipe.get() for part in parts:
if nonce in self.results: base = getattr(base, part)
self.results[nonce].set(data)
def __getattr__(self, name): return base
def wrapper(*args, **kwargs):
nonce = get_random_str()
self.results[nonce] = gevent.event.AsyncResult()
self.pipe.put(nonce, name, args, kwargs)
return self.results[nonce]
return wrapper
def send(self, typ, data):
self.pipe.put((typ.value, data))
class GIPCObject(object): def handle(self, mtype, data):
def __init__(self, inst, pipe): if mtype == IPCMessageType.CALL_FUNC:
self.inst = inst nonce, func, args, kwargs = data
self.pipe = pipe res = self.resolve(func)(*args, **kwargs)
gevent.spawn(self.read_loop) self.send(IPCMessageType.RESPONSE, (nonce, res))
elif mtype == IPCMessageType.GET_ATTR:
nonce, path = data
self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path)))
elif mtype == IPCMessageType.EXECUTE:
nonce, raw = data
func = types.FunctionType(marshal.loads(raw), globals(), nonce)
try:
result = func(self.obj)
except Exception as e:
self.log.exception('Failed to EXECUTE: ')
result = None
self.send(IPCMessageType.RESPONSE, (nonce, result))
elif mtype == IPCMessageType.RESPONSE:
nonce, res = data
if nonce in self.results:
self.results[nonce].set(res)
def read_loop(self): def read_loop(self):
while True: while True:
nonce, func, args, kwargs = self.pipe.get() mtype, data = self.pipe.get()
func = getattr(self.inst, func)
self.pipe.put((nonce, func(*args, **kwargs))) try:
self.handle(mtype, data)
except:
self.log.exception('Error in GIPCProxy:')
def execute(self, func):
nonce = get_random_str(32)
raw = marshal.dumps(func.func_code)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw)))
return result
class IPC(object): def get(self, path):
def __init__(self, sharder): nonce = get_random_str(32)
self.sharder = sharder self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.GET_ATTR.value, (nonce, path)))
return result
def get_shards(self): def call(self, path, *args, **kwargs):
return {} nonce = get_random_str(32)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.CALL_FUNC.value, (nonce, path, args, kwargs)))
return result

48
disco/gateway/sharder.py

@ -1,16 +1,46 @@
from __future__ import absolute_import
import gipc import gipc
import gevent
import types
import marshal
from disco.client import Client from disco.client import Client
from disco.bot import Bot, BotConfig from disco.bot import Bot, BotConfig
from disco.api.client import APIClient from disco.api.client import APIClient
from disco.gateway.ipc.gipc import GIPCObject, GIPCProxy from disco.gateway.ipc.gipc import GIPCProxy
def run_on(id, proxy):
def f(func):
return proxy.call(('run_on', ), id, marshal.dumps(func.func_code))
return f
def run_self(bot):
def f(func):
result = gevent.event.AsyncResult()
result.set(func(bot))
return result
return f
def run_shard(config, id, pipe): def run_shard(config, id, pipe):
import logging
logging.basicConfig(
level=logging.INFO,
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id)
)
config.shard_id = id config.shard_id = id
client = Client(config) client = Client(config)
bot = Bot(client, BotConfig(config.bot)) bot = Bot(client, BotConfig(config.bot))
GIPCObject(bot, pipe) bot.sharder = GIPCProxy(bot, pipe)
bot.shards = {
i: run_on(i, bot.sharder) for i in range(config.shard_count)
if i != id
}
bot.shards[id] = run_self(bot)
bot.run_forever() bot.run_forever()
@ -20,12 +50,22 @@ class AutoSharder(object):
self.client = APIClient(config.token) self.client = APIClient(config.token)
self.shards = {} self.shards = {}
self.config.shard_count = self.client.gateway_bot_get()['shards'] self.config.shard_count = self.client.gateway_bot_get()['shards']
self.config.shard_count = 10
self.test = 1
def run_on(self, id, funccode):
func = types.FunctionType(marshal.loads(funccode), globals(), '_run_on_temp')
return self.shards[id].execute(func).wait(timeout=15)
def run(self): def run(self):
for shard in range(self.shard_count): for shard in range(self.config.shard_count):
if self.config.manhole_enable and shard != 0:
self.config.manhole_enable = False
self.start_shard(shard) self.start_shard(shard)
gevent.sleep(6)
def start_shard(self, id): def start_shard(self, id):
cpipe, ppipe = gipc.pipe(duplex=True) cpipe, ppipe = gipc.pipe(duplex=True)
gipc.start_process(run_shard, (self.config, id, cpipe)) gipc.start_process(run_shard, (self.config, id, cpipe))
self.shards[id] = GIPCProxy(ppipe) self.shards[id] = GIPCProxy(self, ppipe)

48
disco/types/message.py

@ -300,3 +300,51 @@ class Message(SlottedModel):
return user_replace(self.mentions.get(id)) return user_replace(self.mentions.get(id))
return re.sub('<@!?([0-9]+)>', replace, self.content) return re.sub('<@!?([0-9]+)>', replace, self.content)
class MessageTable(object):
def __init__(self, sep=' | ', codeblock=True, header_break=True):
self.header = []
self.entries = []
self.size_index = {}
self.sep = sep
self.codeblock = codeblock
self.header_break = header_break
def recalculate_size_index(self, cols):
for idx, col in enumerate(cols):
if idx not in self.size_index or len(col) > self.size_index[idx]:
self.size_index[idx] = len(col)
def set_header(self, *args):
self.header = args
self.recalculate_size_index(args)
def add(self, *args):
args = list(map(str, args))
self.entries.append(args)
self.recalculate_size_index(args)
def compile_one(self, cols):
data = self.sep.lstrip()
for idx, col in enumerate(cols):
padding = ' ' * ((self.size_index[idx] - len(col)))
data += col + padding + self.sep
return data.rstrip()
def compile(self):
data = []
data.append(self.compile_one(self.header))
if self.header_break:
data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1))
for row in self.entries:
data.append(self.compile_one(row))
if self.codeblock:
return '```' + '\n'.join(data) + '```'
return '\n'.join(data)

Loading…
Cancel
Save