commit
c12aa5344f
14 changed files with 508 additions and 0 deletions
@ -0,0 +1 @@ |
|||
VERSION = '0.0.1' |
@ -0,0 +1,15 @@ |
|||
from disco.api.http import Routes, HTTPClient |
|||
|
|||
from disco.util.logging import LoggingClass |
|||
|
|||
|
|||
class APIClient(LoggingClass): |
|||
def __init__(self, client): |
|||
super(APIClient, self).__init__() |
|||
|
|||
self.client = client |
|||
self.http = HTTPClient(self.client.token) |
|||
|
|||
def gateway(self, version, encoding): |
|||
r = self.http(Routes.GATEWAY_GET) |
|||
return r['url'] + '?v={}&encoding={}'.format(version, encoding) |
@ -0,0 +1,50 @@ |
|||
import requests |
|||
|
|||
from holster.enum import Enum |
|||
|
|||
HTTPMethod = Enum( |
|||
GET='GET', |
|||
POST='POST', |
|||
PUT='PUT', |
|||
PATCH='PATCH', |
|||
DELETE='DELETE', |
|||
) |
|||
|
|||
|
|||
class Routes(object): |
|||
USERS_ME_GET = (HTTPMethod.GET, '/users/@me') |
|||
USERS_ME_PATCH = (HTTPMethod.PATCH, '/users/@me') |
|||
|
|||
GATEWAY_GET = (HTTPMethod.GET, '/gateway') |
|||
|
|||
|
|||
class APIException(Exception): |
|||
def __init__(self, obj): |
|||
self.code = obj['code'] |
|||
self.msg = obj['msg'] |
|||
|
|||
super(APIException, self).__init__(self.msg) |
|||
|
|||
|
|||
class HTTPClient(object): |
|||
BASE_URL = 'https://discordapp.com/api' |
|||
|
|||
def __init__(self, token): |
|||
self.headers = { |
|||
'Authorization': 'Bot ' + token, |
|||
} |
|||
|
|||
def __call__(self, route, *args, **kwargs): |
|||
method, url = route |
|||
|
|||
r = requests.request(str(method), self.BASE_URL + url, *args, **kwargs) |
|||
|
|||
try: |
|||
r.raise_for_status() |
|||
except: |
|||
# TODO: rate limits |
|||
# TODO: check json |
|||
raise APIException(r.json()) |
|||
|
|||
# TODO: check json |
|||
return r.json() |
@ -0,0 +1,26 @@ |
|||
import logging |
|||
import argparse |
|||
|
|||
from gevent import monkey |
|||
|
|||
parser = argparse.ArgumentParser() |
|||
parser.add_argument('--token', help='Bot Authentication Token', required=True) |
|||
|
|||
logging.basicConfig(level=logging.INFO) |
|||
|
|||
|
|||
def main(): |
|||
monkey.patch_all() |
|||
args = parser.parse_args() |
|||
|
|||
from disco.util.token import is_valid_token |
|||
|
|||
if not is_valid_token(args.token): |
|||
print 'Invalid token passed' |
|||
return |
|||
|
|||
from disco.client import DiscoClient |
|||
DiscoClient(args.token).run_forever() |
|||
|
|||
if __name__ == '__main__': |
|||
main() |
@ -0,0 +1,22 @@ |
|||
import logging |
|||
|
|||
from disco.api.client import APIClient |
|||
from disco.gateway.client import GatewayClient |
|||
|
|||
log = logging.getLogger(__name__) |
|||
|
|||
|
|||
class DiscoClient(object): |
|||
def __init__(self, token, sharding=None): |
|||
self.log = log |
|||
self.token = token |
|||
self.sharding = sharding or {'number': 0, 'total': 1} |
|||
|
|||
self.api = APIClient(self) |
|||
self.gw = GatewayClient(self) |
|||
|
|||
def run(self): |
|||
return self.gw.run() |
|||
|
|||
def run_forever(self): |
|||
return self.gw.run().join() |
@ -0,0 +1,121 @@ |
|||
import websocket |
|||
import gevent |
|||
import json |
|||
|
|||
from disco.gateway.packets import ( |
|||
Packet, DispatchPacket, HeartbeatPacket, ReconnectPacket, InvalidSessionPacket, HelloPacket, HeartbeatAckPacket, |
|||
ResumePacket, IdentifyPacket) |
|||
from disco.util.logging import LoggingClass |
|||
|
|||
GATEWAY_VERSION = 6 |
|||
|
|||
|
|||
def log_error(log, msg, w): |
|||
def _f(*args, **kwargs): |
|||
try: |
|||
return w(*args, **kwargs) |
|||
except: |
|||
log.exception(msg) |
|||
raise |
|||
return _f |
|||
|
|||
|
|||
class GatewayClient(LoggingClass): |
|||
def __init__(self, client): |
|||
super(GatewayClient, self).__init__() |
|||
self.client = client |
|||
|
|||
# Websocket connection |
|||
self.ws = None |
|||
|
|||
# State |
|||
self.seq = 0 |
|||
self.session_id = None |
|||
|
|||
# Cached gateway URL |
|||
self._cached_gateway_url = None |
|||
|
|||
# Heartbeat |
|||
self._heartbeat_task = None |
|||
|
|||
self._fatal_error_promise = gevent.event.AsyncResult() |
|||
|
|||
def send(self, packet): |
|||
self.ws.send(json.dumps({ |
|||
'op': int(packet.OP), |
|||
'd': packet.to_dict(), |
|||
})) |
|||
|
|||
def heartbeat_task(self, interval): |
|||
while True: |
|||
self.send(HeartbeatPacket(data=self.seq)) |
|||
gevent.sleep(interval / 1000) |
|||
|
|||
def handle_hello(self, packet): |
|||
self.log.info('Recieved HELLO, starting heartbeater...') |
|||
self._heartbeat_task = gevent.spawn(self.heartbeat_task, packet.heartbeat_interval) |
|||
|
|||
def connect(self): |
|||
if not self._cached_gateway_url: |
|||
self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') |
|||
|
|||
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) |
|||
self.ws = websocket.WebSocketApp( |
|||
self._cached_gateway_url, |
|||
on_message=log_error(self.log, 'Error in on_message:', self.on_message), |
|||
on_error=log_error(self.log, 'Error in on_error:', self.on_error), |
|||
on_open=log_error(self.log, 'Error in on_open:', self.on_open), |
|||
on_close=log_error(self.log, 'Error in on_close:', self.on_close), |
|||
) |
|||
|
|||
def on_message(self, ws, msg): |
|||
# TODO: ZLIB |
|||
|
|||
try: |
|||
packet = Packet.load_json(json.loads(msg)) |
|||
if packet.seq and packet.seq > self.seq: |
|||
self.seq = packet.seq |
|||
except: |
|||
self.log.exception('Failed to load dispatch:') |
|||
return |
|||
|
|||
if isinstance(packet, DispatchPacket): |
|||
self.handle_dispatch(packet) |
|||
elif isinstance(packet, HeartbeatPacket): |
|||
self.handle_heartbeat(packet) |
|||
elif isinstance(packet, ReconnectPacket): |
|||
self.handle_reconnect(packet) |
|||
elif isinstance(packet, InvalidSessionPacket): |
|||
self.handle_invalid_session(packet) |
|||
elif isinstance(packet, HelloPacket): |
|||
self.handle_hello(packet) |
|||
elif isinstance(packet, HeartbeatAckPacket): |
|||
self.handle_heartbeat_ack(packet) |
|||
else: |
|||
raise Exception('Unknown packet: {}'.format(packet)) |
|||
|
|||
def on_error(self, ws, error): |
|||
print 'error', error |
|||
|
|||
def on_open(self, ws): |
|||
print 'on open' |
|||
if self.seq and self.session_id: |
|||
self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) |
|||
else: |
|||
self.send(IdentifyPacket( |
|||
token=self.client.token, |
|||
compress=True, |
|||
large_threshold=250, |
|||
shard=[self.client.sharding['number'], self.client.sharding['total']])) |
|||
|
|||
def on_close(self, ws): |
|||
print 'close' |
|||
|
|||
def run(self): |
|||
self.connect() |
|||
|
|||
# Spawn a thread to run the connection loop forever |
|||
gevent.spawn(self.ws.run_forever) |
|||
|
|||
# Wait for a fatal error |
|||
self._fatal_error_promise.get() |
@ -0,0 +1,116 @@ |
|||
from holster.enum import Enum |
|||
|
|||
from disco.util.oop import TypedClass |
|||
|
|||
OPCode = Enum( |
|||
DISPATCH=0, |
|||
HEARTBEAT=1, |
|||
IDENTIFY=2, |
|||
STATUS_UPDATE=3, |
|||
VOICE_STATE_UPDATE=4, |
|||
VOICE_SERVER_PING=5, |
|||
RESUME=6, |
|||
RECONNECT=7, |
|||
REQUEST_GUILD_MEMBERS=8, |
|||
INVALID_SESSION=9, |
|||
HELLO=10, |
|||
HEARTBEAT_ACK=11, |
|||
GUILD_SYNC=12, |
|||
) |
|||
|
|||
|
|||
class Packet(TypedClass): |
|||
@classmethod |
|||
def load_json(cls, obj): |
|||
if not obj['op']: |
|||
raise Exception('Packet struct missing op key: {}'.format(obj)) |
|||
|
|||
cls = PACKETS.get(obj['op']) |
|||
|
|||
if not cls: |
|||
raise Exception('Unknown OPCode: {}'.format(obj['op'])) |
|||
|
|||
obj.update(obj['d']) |
|||
del obj['d'] |
|||
inst = cls.from_dict(obj) |
|||
inst.seq = obj['s'] |
|||
return inst |
|||
|
|||
|
|||
class DispatchPacket(Packet): |
|||
OP = OPCode.DISPATCH |
|||
|
|||
PARAMS = { |
|||
('d', 'data'): {}, |
|||
('t', 'event'): str, |
|||
} |
|||
|
|||
|
|||
class HeartbeatPacket(Packet): |
|||
OP = OPCode.HEARTBEAT |
|||
|
|||
PARAMS = { |
|||
('d', 'data'): (int, ), |
|||
} |
|||
|
|||
|
|||
class IdentifyPacket(Packet): |
|||
OP = OPCode.IDENTIFY |
|||
|
|||
PARAMS = { |
|||
'token': str, |
|||
'compress': bool, |
|||
'large_threshold': int, |
|||
'shard': [int], |
|||
'properties': 'properties' |
|||
} |
|||
|
|||
@property |
|||
def properties(self): |
|||
return { |
|||
'$os': 'linux', |
|||
'$browser': 'disco', |
|||
'$device': 'disco', |
|||
'$referrer': '', |
|||
} |
|||
|
|||
|
|||
class ResumePacket(Packet): |
|||
OP = OPCode.RESUME |
|||
|
|||
PARAMS = { |
|||
'token': str, |
|||
'session_id': str, |
|||
'seq': int, |
|||
} |
|||
|
|||
|
|||
class ReconnectPacket(Packet): |
|||
OP = OPCode.RECONNECT |
|||
|
|||
|
|||
class InvalidSessionPacket(Packet): |
|||
OP = OPCode.INVALID_SESSION |
|||
|
|||
|
|||
class HelloPacket(Packet): |
|||
OP = OPCode.HELLO |
|||
|
|||
PARAMS = { |
|||
'heartbeat_interval': int, |
|||
'_trace': [str], |
|||
} |
|||
|
|||
|
|||
class HeartbeatAckPacket(Packet): |
|||
OP = OPCode.HEARTBEAT_ACK |
|||
|
|||
|
|||
PACKETS = { |
|||
int(OPCode.DISPATCH): DispatchPacket, |
|||
int(OPCode.HEARTBEAT): HeartbeatPacket, |
|||
int(OPCode.RECONNECT): ReconnectPacket, |
|||
int(OPCode.INVALID_SESSION): InvalidSessionPacket, |
|||
int(OPCode.HELLO): HelloPacket, |
|||
int(OPCode.HEARTBEAT_ACK): HeartbeatAckPacket, |
|||
} |
@ -0,0 +1,8 @@ |
|||
from __future__ import absolute_import |
|||
|
|||
import logging |
|||
|
|||
|
|||
class LoggingClass(object): |
|||
def __init__(self): |
|||
self.log = logging.getLogger(self.__class__.__name__) |
@ -0,0 +1,139 @@ |
|||
import inspect |
|||
|
|||
|
|||
class TypedClassException(Exception): |
|||
pass |
|||
|
|||
|
|||
def construct_typed_class(cls, data): |
|||
obj = cls() |
|||
load_typed_class(obj, data) |
|||
return obj |
|||
|
|||
|
|||
def get_field_and_alias(field): |
|||
if isinstance(field, tuple): |
|||
return field |
|||
else: |
|||
return field, field |
|||
|
|||
|
|||
def get_optional(typ): |
|||
if isinstance(typ, tuple) and len(typ) == 1: |
|||
return True, typ[0] |
|||
return False, typ |
|||
|
|||
|
|||
def cast(typ, value): |
|||
valid = True |
|||
|
|||
# TODO: better exceptions |
|||
if isinstance(typ, list): |
|||
if typ: |
|||
typ = typ[0] |
|||
value = map(typ, value) |
|||
else: |
|||
list(value) |
|||
elif isinstance(typ, dict): |
|||
if typ: |
|||
ktyp, vtyp = typ.items()[0] |
|||
value = {ktyp(k): vtyp(v) for k, v in typ.items()} |
|||
else: |
|||
dict(value) |
|||
elif isinstance(typ, set): |
|||
if typ: |
|||
typ = list(typ)[0] |
|||
value = set(map(typ, value)) |
|||
else: |
|||
set(value) |
|||
elif isinstance(typ, str): |
|||
valid = False |
|||
elif not isinstance(value, typ): |
|||
value = typ(value) |
|||
|
|||
return valid, value |
|||
|
|||
|
|||
def load_typed_class(obj, params, data): |
|||
print obj, params, data |
|||
for field, typ in params.items(): |
|||
field, alias = get_field_and_alias(field) |
|||
|
|||
# Skipped field |
|||
if typ is None: |
|||
continue |
|||
|
|||
optional, typ = get_optional(typ) |
|||
if field not in data and not optional: |
|||
raise TypedClassException('Missing value for attribute `{}`'.format(field)) |
|||
|
|||
value = data[field] |
|||
|
|||
print field, alias, value, typ |
|||
if value is None: |
|||
if not optional: |
|||
raise TypedClassException('Non-optional attribute `{}` cannot take None'.format(field)) |
|||
else: |
|||
valid, value = cast(typ, value) |
|||
if not valid: |
|||
continue |
|||
|
|||
setattr(obj, alias, value) |
|||
|
|||
|
|||
def dump_typed_class(obj, params): |
|||
data = {} |
|||
|
|||
for field, typ in params.items(): |
|||
field, alias = get_field_and_alias(field) |
|||
|
|||
value = getattr(obj, alias, None) |
|||
|
|||
if typ is None: |
|||
data[field] = typ |
|||
continue |
|||
|
|||
optional, typ = get_optional(typ) |
|||
if not value and not optional: |
|||
raise TypedClassException('Missing value for attribute `{}`'.format(field)) |
|||
|
|||
_, value = cast(typ, value) |
|||
data[field] = value |
|||
|
|||
return data |
|||
|
|||
|
|||
def get_params(obj): |
|||
assert(issubclass(obj.__class__, TypedClass)) |
|||
|
|||
if not hasattr(obj.__class__, '_cached_oop_params'): |
|||
base = {} |
|||
for cls in reversed(inspect.getmro(obj.__class__)): |
|||
base.update(getattr(cls, 'PARAMS', {})) |
|||
obj.__class__._cached_oop_params = base |
|||
return obj.__class__._cached_oop_params |
|||
|
|||
|
|||
def load_typed_instance(obj, data): |
|||
return load_typed_class(obj, get_params(obj), data) |
|||
|
|||
|
|||
class TypedClass(object): |
|||
def __init__(self, **kwargs): |
|||
# TODO: validate |
|||
self.__dict__.update(kwargs) |
|||
|
|||
@classmethod |
|||
def from_dict(cls, data): |
|||
self = cls() |
|||
load_typed_instance(self, data) |
|||
return self |
|||
|
|||
def to_dict(self): |
|||
return dump_typed_class(self, get_params(self)) |
|||
|
|||
|
|||
def require_implementation(attr): |
|||
def _f(self, *args, **kwargs): |
|||
raise NotImplementedError('{} must implement method {}', self.__class__.__name, attr) |
|||
return _f |
@ -0,0 +1,10 @@ |
|||
import re |
|||
|
|||
TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}') |
|||
|
|||
|
|||
def is_valid_token(token): |
|||
""" |
|||
Validates a Discord authentication token, returning true if valid |
|||
""" |
|||
return bool(TOKEN_RE.match(token)) |
Loading…
Reference in new issue