commit c12aa5344f6e244443df71958a8b020ea8232879 Author: Andrei Date: Wed Sep 21 22:03:31 2016 -0500 :^) diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/disco/__init__.py b/disco/__init__.py new file mode 100644 index 0000000..a4e55ec --- /dev/null +++ b/disco/__init__.py @@ -0,0 +1 @@ +VERSION = '0.0.1' diff --git a/disco/api/__init__.py b/disco/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/disco/api/client.py b/disco/api/client.py new file mode 100644 index 0000000..7470c02 --- /dev/null +++ b/disco/api/client.py @@ -0,0 +1,15 @@ +from disco.api.http import Routes, HTTPClient + +from disco.util.logging import LoggingClass + + +class APIClient(LoggingClass): + def __init__(self, client): + super(APIClient, self).__init__() + + self.client = client + self.http = HTTPClient(self.client.token) + + def gateway(self, version, encoding): + r = self.http(Routes.GATEWAY_GET) + return r['url'] + '?v={}&encoding={}'.format(version, encoding) diff --git a/disco/api/http.py b/disco/api/http.py new file mode 100644 index 0000000..f7eff59 --- /dev/null +++ b/disco/api/http.py @@ -0,0 +1,50 @@ +import requests + +from holster.enum import Enum + +HTTPMethod = Enum( + GET='GET', + POST='POST', + PUT='PUT', + PATCH='PATCH', + DELETE='DELETE', +) + + +class Routes(object): + USERS_ME_GET = (HTTPMethod.GET, '/users/@me') + USERS_ME_PATCH = (HTTPMethod.PATCH, '/users/@me') + + GATEWAY_GET = (HTTPMethod.GET, '/gateway') + + +class APIException(Exception): + def __init__(self, obj): + self.code = obj['code'] + self.msg = obj['msg'] + + super(APIException, self).__init__(self.msg) + + +class HTTPClient(object): + BASE_URL = 'https://discordapp.com/api' + + def __init__(self, token): + self.headers = { + 'Authorization': 'Bot ' + token, + } + + def __call__(self, route, *args, **kwargs): + method, url = route + + r = requests.request(str(method), self.BASE_URL + url, *args, **kwargs) + + try: + r.raise_for_status() + except: + # TODO: rate limits + # TODO: check json + raise APIException(r.json()) + + # TODO: check json + return r.json() diff --git a/disco/cli.py b/disco/cli.py new file mode 100644 index 0000000..90cc701 --- /dev/null +++ b/disco/cli.py @@ -0,0 +1,26 @@ +import logging +import argparse + +from gevent import monkey + +parser = argparse.ArgumentParser() +parser.add_argument('--token', help='Bot Authentication Token', required=True) + +logging.basicConfig(level=logging.INFO) + + +def main(): + monkey.patch_all() + args = parser.parse_args() + + from disco.util.token import is_valid_token + + if not is_valid_token(args.token): + print 'Invalid token passed' + return + + from disco.client import DiscoClient + DiscoClient(args.token).run_forever() + +if __name__ == '__main__': + main() diff --git a/disco/client.py b/disco/client.py new file mode 100644 index 0000000..f2e3aff --- /dev/null +++ b/disco/client.py @@ -0,0 +1,22 @@ +import logging + +from disco.api.client import APIClient +from disco.gateway.client import GatewayClient + +log = logging.getLogger(__name__) + + +class DiscoClient(object): + def __init__(self, token, sharding=None): + self.log = log + self.token = token + self.sharding = sharding or {'number': 0, 'total': 1} + + self.api = APIClient(self) + self.gw = GatewayClient(self) + + def run(self): + return self.gw.run() + + def run_forever(self): + return self.gw.run().join() diff --git a/disco/gateway/__init__.py b/disco/gateway/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/disco/gateway/client.py b/disco/gateway/client.py new file mode 100644 index 0000000..d408c57 --- /dev/null +++ b/disco/gateway/client.py @@ -0,0 +1,121 @@ +import websocket +import gevent +import json + +from disco.gateway.packets import ( + Packet, DispatchPacket, HeartbeatPacket, ReconnectPacket, InvalidSessionPacket, HelloPacket, HeartbeatAckPacket, + ResumePacket, IdentifyPacket) +from disco.util.logging import LoggingClass + +GATEWAY_VERSION = 6 + + +def log_error(log, msg, w): + def _f(*args, **kwargs): + try: + return w(*args, **kwargs) + except: + log.exception(msg) + raise + return _f + + +class GatewayClient(LoggingClass): + def __init__(self, client): + super(GatewayClient, self).__init__() + self.client = client + + # Websocket connection + self.ws = None + + # State + self.seq = 0 + self.session_id = None + + # Cached gateway URL + self._cached_gateway_url = None + + # Heartbeat + self._heartbeat_task = None + + self._fatal_error_promise = gevent.event.AsyncResult() + + def send(self, packet): + self.ws.send(json.dumps({ + 'op': int(packet.OP), + 'd': packet.to_dict(), + })) + + def heartbeat_task(self, interval): + while True: + self.send(HeartbeatPacket(data=self.seq)) + gevent.sleep(interval / 1000) + + def handle_hello(self, packet): + self.log.info('Recieved HELLO, starting heartbeater...') + self._heartbeat_task = gevent.spawn(self.heartbeat_task, packet.heartbeat_interval) + + def connect(self): + if not self._cached_gateway_url: + self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') + + self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) + self.ws = websocket.WebSocketApp( + self._cached_gateway_url, + on_message=log_error(self.log, 'Error in on_message:', self.on_message), + on_error=log_error(self.log, 'Error in on_error:', self.on_error), + on_open=log_error(self.log, 'Error in on_open:', self.on_open), + on_close=log_error(self.log, 'Error in on_close:', self.on_close), + ) + + def on_message(self, ws, msg): + # TODO: ZLIB + + try: + packet = Packet.load_json(json.loads(msg)) + if packet.seq and packet.seq > self.seq: + self.seq = packet.seq + except: + self.log.exception('Failed to load dispatch:') + return + + if isinstance(packet, DispatchPacket): + self.handle_dispatch(packet) + elif isinstance(packet, HeartbeatPacket): + self.handle_heartbeat(packet) + elif isinstance(packet, ReconnectPacket): + self.handle_reconnect(packet) + elif isinstance(packet, InvalidSessionPacket): + self.handle_invalid_session(packet) + elif isinstance(packet, HelloPacket): + self.handle_hello(packet) + elif isinstance(packet, HeartbeatAckPacket): + self.handle_heartbeat_ack(packet) + else: + raise Exception('Unknown packet: {}'.format(packet)) + + def on_error(self, ws, error): + print 'error', error + + def on_open(self, ws): + print 'on open' + if self.seq and self.session_id: + self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) + else: + self.send(IdentifyPacket( + token=self.client.token, + compress=True, + large_threshold=250, + shard=[self.client.sharding['number'], self.client.sharding['total']])) + + def on_close(self, ws): + print 'close' + + def run(self): + self.connect() + + # Spawn a thread to run the connection loop forever + gevent.spawn(self.ws.run_forever) + + # Wait for a fatal error + self._fatal_error_promise.get() diff --git a/disco/gateway/packets.py b/disco/gateway/packets.py new file mode 100644 index 0000000..b4dfdd0 --- /dev/null +++ b/disco/gateway/packets.py @@ -0,0 +1,116 @@ +from holster.enum import Enum + +from disco.util.oop import TypedClass + +OPCode = Enum( + DISPATCH=0, + HEARTBEAT=1, + IDENTIFY=2, + STATUS_UPDATE=3, + VOICE_STATE_UPDATE=4, + VOICE_SERVER_PING=5, + RESUME=6, + RECONNECT=7, + REQUEST_GUILD_MEMBERS=8, + INVALID_SESSION=9, + HELLO=10, + HEARTBEAT_ACK=11, + GUILD_SYNC=12, +) + + +class Packet(TypedClass): + @classmethod + def load_json(cls, obj): + if not obj['op']: + raise Exception('Packet struct missing op key: {}'.format(obj)) + + cls = PACKETS.get(obj['op']) + + if not cls: + raise Exception('Unknown OPCode: {}'.format(obj['op'])) + + obj.update(obj['d']) + del obj['d'] + inst = cls.from_dict(obj) + inst.seq = obj['s'] + return inst + + +class DispatchPacket(Packet): + OP = OPCode.DISPATCH + + PARAMS = { + ('d', 'data'): {}, + ('t', 'event'): str, + } + + +class HeartbeatPacket(Packet): + OP = OPCode.HEARTBEAT + + PARAMS = { + ('d', 'data'): (int, ), + } + + +class IdentifyPacket(Packet): + OP = OPCode.IDENTIFY + + PARAMS = { + 'token': str, + 'compress': bool, + 'large_threshold': int, + 'shard': [int], + 'properties': 'properties' + } + + @property + def properties(self): + return { + '$os': 'linux', + '$browser': 'disco', + '$device': 'disco', + '$referrer': '', + } + + +class ResumePacket(Packet): + OP = OPCode.RESUME + + PARAMS = { + 'token': str, + 'session_id': str, + 'seq': int, + } + + +class ReconnectPacket(Packet): + OP = OPCode.RECONNECT + + +class InvalidSessionPacket(Packet): + OP = OPCode.INVALID_SESSION + + +class HelloPacket(Packet): + OP = OPCode.HELLO + + PARAMS = { + 'heartbeat_interval': int, + '_trace': [str], + } + + +class HeartbeatAckPacket(Packet): + OP = OPCode.HEARTBEAT_ACK + + +PACKETS = { + int(OPCode.DISPATCH): DispatchPacket, + int(OPCode.HEARTBEAT): HeartbeatPacket, + int(OPCode.RECONNECT): ReconnectPacket, + int(OPCode.INVALID_SESSION): InvalidSessionPacket, + int(OPCode.HELLO): HelloPacket, + int(OPCode.HEARTBEAT_ACK): HeartbeatAckPacket, +} diff --git a/disco/util/__init__.py b/disco/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/disco/util/logging.py b/disco/util/logging.py new file mode 100644 index 0000000..24bfe6c --- /dev/null +++ b/disco/util/logging.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import + +import logging + + +class LoggingClass(object): + def __init__(self): + self.log = logging.getLogger(self.__class__.__name__) diff --git a/disco/util/oop.py b/disco/util/oop.py new file mode 100644 index 0000000..14b3509 --- /dev/null +++ b/disco/util/oop.py @@ -0,0 +1,139 @@ +import inspect + + +class TypedClassException(Exception): + pass + + +def construct_typed_class(cls, data): + obj = cls() + load_typed_class(obj, data) + return obj + + +def get_field_and_alias(field): + if isinstance(field, tuple): + return field + else: + return field, field + + +def get_optional(typ): + if isinstance(typ, tuple) and len(typ) == 1: + return True, typ[0] + return False, typ + + +def cast(typ, value): + valid = True + + # TODO: better exceptions + if isinstance(typ, list): + if typ: + typ = typ[0] + value = map(typ, value) + else: + list(value) + elif isinstance(typ, dict): + if typ: + ktyp, vtyp = typ.items()[0] + value = {ktyp(k): vtyp(v) for k, v in typ.items()} + else: + dict(value) + elif isinstance(typ, set): + if typ: + typ = list(typ)[0] + value = set(map(typ, value)) + else: + set(value) + elif isinstance(typ, str): + valid = False + elif not isinstance(value, typ): + value = typ(value) + + return valid, value + + +def load_typed_class(obj, params, data): + print obj, params, data + for field, typ in params.items(): + field, alias = get_field_and_alias(field) + + # Skipped field + if typ is None: + continue + + optional, typ = get_optional(typ) + if field not in data and not optional: + raise TypedClassException('Missing value for attribute `{}`'.format(field)) + + value = data[field] + + print field, alias, value, typ + if value is None: + if not optional: + raise TypedClassException('Non-optional attribute `{}` cannot take None'.format(field)) + else: + valid, value = cast(typ, value) + if not valid: + continue + + setattr(obj, alias, value) + + +def dump_typed_class(obj, params): + data = {} + + for field, typ in params.items(): + field, alias = get_field_and_alias(field) + + value = getattr(obj, alias, None) + + if typ is None: + data[field] = typ + continue + + optional, typ = get_optional(typ) + if not value and not optional: + raise TypedClassException('Missing value for attribute `{}`'.format(field)) + + _, value = cast(typ, value) + data[field] = value + + return data + + +def get_params(obj): + assert(issubclass(obj.__class__, TypedClass)) + + if not hasattr(obj.__class__, '_cached_oop_params'): + base = {} + for cls in reversed(inspect.getmro(obj.__class__)): + base.update(getattr(cls, 'PARAMS', {})) + obj.__class__._cached_oop_params = base + return obj.__class__._cached_oop_params + + +def load_typed_instance(obj, data): + return load_typed_class(obj, get_params(obj), data) + + +class TypedClass(object): + def __init__(self, **kwargs): + # TODO: validate + self.__dict__.update(kwargs) + + @classmethod + def from_dict(cls, data): + self = cls() + load_typed_instance(self, data) + return self + + def to_dict(self): + return dump_typed_class(self, get_params(self)) + + +def require_implementation(attr): + def _f(self, *args, **kwargs): + raise NotImplementedError('{} must implement method {}', self.__class__.__name, attr) + return _f diff --git a/disco/util/token.py b/disco/util/token.py new file mode 100644 index 0000000..c48beca --- /dev/null +++ b/disco/util/token.py @@ -0,0 +1,10 @@ +import re + +TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}') + + +def is_valid_token(token): + """ + Validates a Discord authentication token, returning true if valid + """ + return bool(TOKEN_RE.match(token))