commit
c12aa5344f
14 changed files with 508 additions and 0 deletions
@ -0,0 +1 @@ |
|||||
|
VERSION = '0.0.1' |
@ -0,0 +1,15 @@ |
|||||
|
from disco.api.http import Routes, HTTPClient |
||||
|
|
||||
|
from disco.util.logging import LoggingClass |
||||
|
|
||||
|
|
||||
|
class APIClient(LoggingClass): |
||||
|
def __init__(self, client): |
||||
|
super(APIClient, self).__init__() |
||||
|
|
||||
|
self.client = client |
||||
|
self.http = HTTPClient(self.client.token) |
||||
|
|
||||
|
def gateway(self, version, encoding): |
||||
|
r = self.http(Routes.GATEWAY_GET) |
||||
|
return r['url'] + '?v={}&encoding={}'.format(version, encoding) |
@ -0,0 +1,50 @@ |
|||||
|
import requests |
||||
|
|
||||
|
from holster.enum import Enum |
||||
|
|
||||
|
HTTPMethod = Enum( |
||||
|
GET='GET', |
||||
|
POST='POST', |
||||
|
PUT='PUT', |
||||
|
PATCH='PATCH', |
||||
|
DELETE='DELETE', |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class Routes(object): |
||||
|
USERS_ME_GET = (HTTPMethod.GET, '/users/@me') |
||||
|
USERS_ME_PATCH = (HTTPMethod.PATCH, '/users/@me') |
||||
|
|
||||
|
GATEWAY_GET = (HTTPMethod.GET, '/gateway') |
||||
|
|
||||
|
|
||||
|
class APIException(Exception): |
||||
|
def __init__(self, obj): |
||||
|
self.code = obj['code'] |
||||
|
self.msg = obj['msg'] |
||||
|
|
||||
|
super(APIException, self).__init__(self.msg) |
||||
|
|
||||
|
|
||||
|
class HTTPClient(object): |
||||
|
BASE_URL = 'https://discordapp.com/api' |
||||
|
|
||||
|
def __init__(self, token): |
||||
|
self.headers = { |
||||
|
'Authorization': 'Bot ' + token, |
||||
|
} |
||||
|
|
||||
|
def __call__(self, route, *args, **kwargs): |
||||
|
method, url = route |
||||
|
|
||||
|
r = requests.request(str(method), self.BASE_URL + url, *args, **kwargs) |
||||
|
|
||||
|
try: |
||||
|
r.raise_for_status() |
||||
|
except: |
||||
|
# TODO: rate limits |
||||
|
# TODO: check json |
||||
|
raise APIException(r.json()) |
||||
|
|
||||
|
# TODO: check json |
||||
|
return r.json() |
@ -0,0 +1,26 @@ |
|||||
|
import logging |
||||
|
import argparse |
||||
|
|
||||
|
from gevent import monkey |
||||
|
|
||||
|
parser = argparse.ArgumentParser() |
||||
|
parser.add_argument('--token', help='Bot Authentication Token', required=True) |
||||
|
|
||||
|
logging.basicConfig(level=logging.INFO) |
||||
|
|
||||
|
|
||||
|
def main(): |
||||
|
monkey.patch_all() |
||||
|
args = parser.parse_args() |
||||
|
|
||||
|
from disco.util.token import is_valid_token |
||||
|
|
||||
|
if not is_valid_token(args.token): |
||||
|
print 'Invalid token passed' |
||||
|
return |
||||
|
|
||||
|
from disco.client import DiscoClient |
||||
|
DiscoClient(args.token).run_forever() |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
@ -0,0 +1,22 @@ |
|||||
|
import logging |
||||
|
|
||||
|
from disco.api.client import APIClient |
||||
|
from disco.gateway.client import GatewayClient |
||||
|
|
||||
|
log = logging.getLogger(__name__) |
||||
|
|
||||
|
|
||||
|
class DiscoClient(object): |
||||
|
def __init__(self, token, sharding=None): |
||||
|
self.log = log |
||||
|
self.token = token |
||||
|
self.sharding = sharding or {'number': 0, 'total': 1} |
||||
|
|
||||
|
self.api = APIClient(self) |
||||
|
self.gw = GatewayClient(self) |
||||
|
|
||||
|
def run(self): |
||||
|
return self.gw.run() |
||||
|
|
||||
|
def run_forever(self): |
||||
|
return self.gw.run().join() |
@ -0,0 +1,121 @@ |
|||||
|
import websocket |
||||
|
import gevent |
||||
|
import json |
||||
|
|
||||
|
from disco.gateway.packets import ( |
||||
|
Packet, DispatchPacket, HeartbeatPacket, ReconnectPacket, InvalidSessionPacket, HelloPacket, HeartbeatAckPacket, |
||||
|
ResumePacket, IdentifyPacket) |
||||
|
from disco.util.logging import LoggingClass |
||||
|
|
||||
|
GATEWAY_VERSION = 6 |
||||
|
|
||||
|
|
||||
|
def log_error(log, msg, w): |
||||
|
def _f(*args, **kwargs): |
||||
|
try: |
||||
|
return w(*args, **kwargs) |
||||
|
except: |
||||
|
log.exception(msg) |
||||
|
raise |
||||
|
return _f |
||||
|
|
||||
|
|
||||
|
class GatewayClient(LoggingClass): |
||||
|
def __init__(self, client): |
||||
|
super(GatewayClient, self).__init__() |
||||
|
self.client = client |
||||
|
|
||||
|
# Websocket connection |
||||
|
self.ws = None |
||||
|
|
||||
|
# State |
||||
|
self.seq = 0 |
||||
|
self.session_id = None |
||||
|
|
||||
|
# Cached gateway URL |
||||
|
self._cached_gateway_url = None |
||||
|
|
||||
|
# Heartbeat |
||||
|
self._heartbeat_task = None |
||||
|
|
||||
|
self._fatal_error_promise = gevent.event.AsyncResult() |
||||
|
|
||||
|
def send(self, packet): |
||||
|
self.ws.send(json.dumps({ |
||||
|
'op': int(packet.OP), |
||||
|
'd': packet.to_dict(), |
||||
|
})) |
||||
|
|
||||
|
def heartbeat_task(self, interval): |
||||
|
while True: |
||||
|
self.send(HeartbeatPacket(data=self.seq)) |
||||
|
gevent.sleep(interval / 1000) |
||||
|
|
||||
|
def handle_hello(self, packet): |
||||
|
self.log.info('Recieved HELLO, starting heartbeater...') |
||||
|
self._heartbeat_task = gevent.spawn(self.heartbeat_task, packet.heartbeat_interval) |
||||
|
|
||||
|
def connect(self): |
||||
|
if not self._cached_gateway_url: |
||||
|
self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') |
||||
|
|
||||
|
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) |
||||
|
self.ws = websocket.WebSocketApp( |
||||
|
self._cached_gateway_url, |
||||
|
on_message=log_error(self.log, 'Error in on_message:', self.on_message), |
||||
|
on_error=log_error(self.log, 'Error in on_error:', self.on_error), |
||||
|
on_open=log_error(self.log, 'Error in on_open:', self.on_open), |
||||
|
on_close=log_error(self.log, 'Error in on_close:', self.on_close), |
||||
|
) |
||||
|
|
||||
|
def on_message(self, ws, msg): |
||||
|
# TODO: ZLIB |
||||
|
|
||||
|
try: |
||||
|
packet = Packet.load_json(json.loads(msg)) |
||||
|
if packet.seq and packet.seq > self.seq: |
||||
|
self.seq = packet.seq |
||||
|
except: |
||||
|
self.log.exception('Failed to load dispatch:') |
||||
|
return |
||||
|
|
||||
|
if isinstance(packet, DispatchPacket): |
||||
|
self.handle_dispatch(packet) |
||||
|
elif isinstance(packet, HeartbeatPacket): |
||||
|
self.handle_heartbeat(packet) |
||||
|
elif isinstance(packet, ReconnectPacket): |
||||
|
self.handle_reconnect(packet) |
||||
|
elif isinstance(packet, InvalidSessionPacket): |
||||
|
self.handle_invalid_session(packet) |
||||
|
elif isinstance(packet, HelloPacket): |
||||
|
self.handle_hello(packet) |
||||
|
elif isinstance(packet, HeartbeatAckPacket): |
||||
|
self.handle_heartbeat_ack(packet) |
||||
|
else: |
||||
|
raise Exception('Unknown packet: {}'.format(packet)) |
||||
|
|
||||
|
def on_error(self, ws, error): |
||||
|
print 'error', error |
||||
|
|
||||
|
def on_open(self, ws): |
||||
|
print 'on open' |
||||
|
if self.seq and self.session_id: |
||||
|
self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) |
||||
|
else: |
||||
|
self.send(IdentifyPacket( |
||||
|
token=self.client.token, |
||||
|
compress=True, |
||||
|
large_threshold=250, |
||||
|
shard=[self.client.sharding['number'], self.client.sharding['total']])) |
||||
|
|
||||
|
def on_close(self, ws): |
||||
|
print 'close' |
||||
|
|
||||
|
def run(self): |
||||
|
self.connect() |
||||
|
|
||||
|
# Spawn a thread to run the connection loop forever |
||||
|
gevent.spawn(self.ws.run_forever) |
||||
|
|
||||
|
# Wait for a fatal error |
||||
|
self._fatal_error_promise.get() |
@ -0,0 +1,116 @@ |
|||||
|
from holster.enum import Enum |
||||
|
|
||||
|
from disco.util.oop import TypedClass |
||||
|
|
||||
|
OPCode = Enum( |
||||
|
DISPATCH=0, |
||||
|
HEARTBEAT=1, |
||||
|
IDENTIFY=2, |
||||
|
STATUS_UPDATE=3, |
||||
|
VOICE_STATE_UPDATE=4, |
||||
|
VOICE_SERVER_PING=5, |
||||
|
RESUME=6, |
||||
|
RECONNECT=7, |
||||
|
REQUEST_GUILD_MEMBERS=8, |
||||
|
INVALID_SESSION=9, |
||||
|
HELLO=10, |
||||
|
HEARTBEAT_ACK=11, |
||||
|
GUILD_SYNC=12, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class Packet(TypedClass): |
||||
|
@classmethod |
||||
|
def load_json(cls, obj): |
||||
|
if not obj['op']: |
||||
|
raise Exception('Packet struct missing op key: {}'.format(obj)) |
||||
|
|
||||
|
cls = PACKETS.get(obj['op']) |
||||
|
|
||||
|
if not cls: |
||||
|
raise Exception('Unknown OPCode: {}'.format(obj['op'])) |
||||
|
|
||||
|
obj.update(obj['d']) |
||||
|
del obj['d'] |
||||
|
inst = cls.from_dict(obj) |
||||
|
inst.seq = obj['s'] |
||||
|
return inst |
||||
|
|
||||
|
|
||||
|
class DispatchPacket(Packet): |
||||
|
OP = OPCode.DISPATCH |
||||
|
|
||||
|
PARAMS = { |
||||
|
('d', 'data'): {}, |
||||
|
('t', 'event'): str, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
class HeartbeatPacket(Packet): |
||||
|
OP = OPCode.HEARTBEAT |
||||
|
|
||||
|
PARAMS = { |
||||
|
('d', 'data'): (int, ), |
||||
|
} |
||||
|
|
||||
|
|
||||
|
class IdentifyPacket(Packet): |
||||
|
OP = OPCode.IDENTIFY |
||||
|
|
||||
|
PARAMS = { |
||||
|
'token': str, |
||||
|
'compress': bool, |
||||
|
'large_threshold': int, |
||||
|
'shard': [int], |
||||
|
'properties': 'properties' |
||||
|
} |
||||
|
|
||||
|
@property |
||||
|
def properties(self): |
||||
|
return { |
||||
|
'$os': 'linux', |
||||
|
'$browser': 'disco', |
||||
|
'$device': 'disco', |
||||
|
'$referrer': '', |
||||
|
} |
||||
|
|
||||
|
|
||||
|
class ResumePacket(Packet): |
||||
|
OP = OPCode.RESUME |
||||
|
|
||||
|
PARAMS = { |
||||
|
'token': str, |
||||
|
'session_id': str, |
||||
|
'seq': int, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
class ReconnectPacket(Packet): |
||||
|
OP = OPCode.RECONNECT |
||||
|
|
||||
|
|
||||
|
class InvalidSessionPacket(Packet): |
||||
|
OP = OPCode.INVALID_SESSION |
||||
|
|
||||
|
|
||||
|
class HelloPacket(Packet): |
||||
|
OP = OPCode.HELLO |
||||
|
|
||||
|
PARAMS = { |
||||
|
'heartbeat_interval': int, |
||||
|
'_trace': [str], |
||||
|
} |
||||
|
|
||||
|
|
||||
|
class HeartbeatAckPacket(Packet): |
||||
|
OP = OPCode.HEARTBEAT_ACK |
||||
|
|
||||
|
|
||||
|
PACKETS = { |
||||
|
int(OPCode.DISPATCH): DispatchPacket, |
||||
|
int(OPCode.HEARTBEAT): HeartbeatPacket, |
||||
|
int(OPCode.RECONNECT): ReconnectPacket, |
||||
|
int(OPCode.INVALID_SESSION): InvalidSessionPacket, |
||||
|
int(OPCode.HELLO): HelloPacket, |
||||
|
int(OPCode.HEARTBEAT_ACK): HeartbeatAckPacket, |
||||
|
} |
@ -0,0 +1,8 @@ |
|||||
|
from __future__ import absolute_import |
||||
|
|
||||
|
import logging |
||||
|
|
||||
|
|
||||
|
class LoggingClass(object): |
||||
|
def __init__(self): |
||||
|
self.log = logging.getLogger(self.__class__.__name__) |
@ -0,0 +1,139 @@ |
|||||
|
import inspect |
||||
|
|
||||
|
|
||||
|
class TypedClassException(Exception): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
def construct_typed_class(cls, data): |
||||
|
obj = cls() |
||||
|
load_typed_class(obj, data) |
||||
|
return obj |
||||
|
|
||||
|
|
||||
|
def get_field_and_alias(field): |
||||
|
if isinstance(field, tuple): |
||||
|
return field |
||||
|
else: |
||||
|
return field, field |
||||
|
|
||||
|
|
||||
|
def get_optional(typ): |
||||
|
if isinstance(typ, tuple) and len(typ) == 1: |
||||
|
return True, typ[0] |
||||
|
return False, typ |
||||
|
|
||||
|
|
||||
|
def cast(typ, value): |
||||
|
valid = True |
||||
|
|
||||
|
# TODO: better exceptions |
||||
|
if isinstance(typ, list): |
||||
|
if typ: |
||||
|
typ = typ[0] |
||||
|
value = map(typ, value) |
||||
|
else: |
||||
|
list(value) |
||||
|
elif isinstance(typ, dict): |
||||
|
if typ: |
||||
|
ktyp, vtyp = typ.items()[0] |
||||
|
value = {ktyp(k): vtyp(v) for k, v in typ.items()} |
||||
|
else: |
||||
|
dict(value) |
||||
|
elif isinstance(typ, set): |
||||
|
if typ: |
||||
|
typ = list(typ)[0] |
||||
|
value = set(map(typ, value)) |
||||
|
else: |
||||
|
set(value) |
||||
|
elif isinstance(typ, str): |
||||
|
valid = False |
||||
|
elif not isinstance(value, typ): |
||||
|
value = typ(value) |
||||
|
|
||||
|
return valid, value |
||||
|
|
||||
|
|
||||
|
def load_typed_class(obj, params, data): |
||||
|
print obj, params, data |
||||
|
for field, typ in params.items(): |
||||
|
field, alias = get_field_and_alias(field) |
||||
|
|
||||
|
# Skipped field |
||||
|
if typ is None: |
||||
|
continue |
||||
|
|
||||
|
optional, typ = get_optional(typ) |
||||
|
if field not in data and not optional: |
||||
|
raise TypedClassException('Missing value for attribute `{}`'.format(field)) |
||||
|
|
||||
|
value = data[field] |
||||
|
|
||||
|
print field, alias, value, typ |
||||
|
if value is None: |
||||
|
if not optional: |
||||
|
raise TypedClassException('Non-optional attribute `{}` cannot take None'.format(field)) |
||||
|
else: |
||||
|
valid, value = cast(typ, value) |
||||
|
if not valid: |
||||
|
continue |
||||
|
|
||||
|
setattr(obj, alias, value) |
||||
|
|
||||
|
|
||||
|
def dump_typed_class(obj, params): |
||||
|
data = {} |
||||
|
|
||||
|
for field, typ in params.items(): |
||||
|
field, alias = get_field_and_alias(field) |
||||
|
|
||||
|
value = getattr(obj, alias, None) |
||||
|
|
||||
|
if typ is None: |
||||
|
data[field] = typ |
||||
|
continue |
||||
|
|
||||
|
optional, typ = get_optional(typ) |
||||
|
if not value and not optional: |
||||
|
raise TypedClassException('Missing value for attribute `{}`'.format(field)) |
||||
|
|
||||
|
_, value = cast(typ, value) |
||||
|
data[field] = value |
||||
|
|
||||
|
return data |
||||
|
|
||||
|
|
||||
|
def get_params(obj): |
||||
|
assert(issubclass(obj.__class__, TypedClass)) |
||||
|
|
||||
|
if not hasattr(obj.__class__, '_cached_oop_params'): |
||||
|
base = {} |
||||
|
for cls in reversed(inspect.getmro(obj.__class__)): |
||||
|
base.update(getattr(cls, 'PARAMS', {})) |
||||
|
obj.__class__._cached_oop_params = base |
||||
|
return obj.__class__._cached_oop_params |
||||
|
|
||||
|
|
||||
|
def load_typed_instance(obj, data): |
||||
|
return load_typed_class(obj, get_params(obj), data) |
||||
|
|
||||
|
|
||||
|
class TypedClass(object): |
||||
|
def __init__(self, **kwargs): |
||||
|
# TODO: validate |
||||
|
self.__dict__.update(kwargs) |
||||
|
|
||||
|
@classmethod |
||||
|
def from_dict(cls, data): |
||||
|
self = cls() |
||||
|
load_typed_instance(self, data) |
||||
|
return self |
||||
|
|
||||
|
def to_dict(self): |
||||
|
return dump_typed_class(self, get_params(self)) |
||||
|
|
||||
|
|
||||
|
def require_implementation(attr): |
||||
|
def _f(self, *args, **kwargs): |
||||
|
raise NotImplementedError('{} must implement method {}', self.__class__.__name, attr) |
||||
|
return _f |
@ -0,0 +1,10 @@ |
|||||
|
import re |
||||
|
|
||||
|
TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}') |
||||
|
|
||||
|
|
||||
|
def is_valid_token(token): |
||||
|
""" |
||||
|
Validates a Discord authentication token, returning true if valid |
||||
|
""" |
||||
|
return bool(TOKEN_RE.match(token)) |
Loading…
Reference in new issue