From 13ee9eae093f15c04dbeb2e5e5be95dd6801ddfd Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 24 Sep 2016 22:25:07 -0500 Subject: [PATCH] Refactoring, start work on voice support --- disco/bot/command.py | 4 ++ disco/client.py | 1 + disco/gateway/client.py | 96 +++++++++++++-------------- disco/gateway/packets.py | 85 ------------------------ disco/state.py | 25 ++++--- disco/types/base.py | 7 ++ disco/types/channel.py | 14 +++- disco/types/guild.py | 36 +++++++++- disco/types/user.py | 3 + disco/types/voice.py | 24 ++++++- disco/util/__init__.py | 11 ++++ disco/util/cache.py | 11 +++- disco/util/json.py | 11 ++++ disco/util/oop.py | 139 --------------------------------------- disco/util/websocket.py | 28 ++++++++ disco/voice/__init__.py | 0 disco/voice/client.py | 119 +++++++++++++++++++++++++++++++++ disco/voice/packets.py | 10 +++ examples/basic_plugin.py | 11 ++++ 19 files changed, 342 insertions(+), 293 deletions(-) create mode 100644 disco/util/json.py delete mode 100644 disco/util/oop.py create mode 100644 disco/util/websocket.py create mode 100644 disco/voice/__init__.py create mode 100644 disco/voice/client.py create mode 100644 disco/voice/packets.py diff --git a/disco/bot/command.py b/disco/bot/command.py index 87fb4f9..1702f83 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -15,6 +15,10 @@ class CommandEvent(object): self.name = self.match.group(1) self.args = self.match.group(2).strip().split(' ') + @cached_property + def member(self): + return self.guild.get_member(self.actor) + @property def channel(self): return self.msg.channel diff --git a/disco/client.py b/disco/client.py index 0d6cc26..e7872c8 100644 --- a/disco/client.py +++ b/disco/client.py @@ -17,6 +17,7 @@ class DiscoClient(object): self.sharding = sharding or {'number': 0, 'total': 1} self.events = Emitter(gevent.spawn) + self.packets = Emitter(gevent.spawn) self.state = State(self) self.api = APIClient(self) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 7ed5f6c..90b167d 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -1,27 +1,17 @@ -import websocket import gevent -import json import zlib -import six import ssl -from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket +from disco.gateway.packets import OPCode from disco.gateway.events import GatewayEvent +from disco.util.json import loads, dumps +from disco.util.websocket import Websocket from disco.util.logging import LoggingClass GATEWAY_VERSION = 6 TEN_MEGABYTES = 10490000 -# Hack to get websocket close information -def websocket_get_close_args_override(data): - if data and len(data) >= 2: - code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) - reason = data[2:].decode('utf-8') - return [code, reason] - return [None, None] - - class GatewayClient(LoggingClass): MAX_RECONNECTS = 5 @@ -29,7 +19,19 @@ class GatewayClient(LoggingClass): super(GatewayClient, self).__init__() self.client = client - self.client.events.on('Ready', self.on_ready) + self.events = client.events + self.packets = client.packets + + # Create emitter and bind to gateway payloads + self.packets.on(OPCode.DISPATCH, self.handle_dispatch) + self.packets.on(OPCode.HEARTBEAT, self.handle_heartbeat) + self.packets.on(OPCode.RECONNECT, self.handle_reconnect) + self.packets.on(OPCode.INVALID_SESSION, self.handle_invalid_session) + self.packets.on(OPCode.HELLO, self.handle_hello) + self.packets.on(OPCode.HEARTBEAT_ACK, self.handle_heartbeat_ack) + + # Bind to ready payload + self.events.on('Ready', self.on_ready) # Websocket connection self.ws = None @@ -46,15 +48,15 @@ class GatewayClient(LoggingClass): # Heartbeat self._heartbeat_task = None - def send(self, packet): - self.ws.send(json.dumps({ - 'op': int(packet.OP), - 'd': packet.to_dict(), + def send(self, op, data): + self.ws.send(dumps({ + 'op': op.value, + 'd': data, })) def heartbeat_task(self, interval): while True: - self.send(HeartbeatPacket(data=self.seq)) + self.send(OPCode.HEARTBEAT, self.seq) gevent.sleep(interval / 1000) def handle_dispatch(self, packet): @@ -63,7 +65,7 @@ class GatewayClient(LoggingClass): self.client.events.emit(obj.__class__.__name__, obj) def handle_heartbeat(self, packet): - self.send(HeartbeatPacket(data=self.seq)) + self.send(OPCode.HEARTBEAT, self.seq) def handle_reconnect(self, packet): self.log.warning('Received RECONNECT request, forcing a fresh reconnect') @@ -92,14 +94,13 @@ class GatewayClient(LoggingClass): self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) - self.ws = websocket.WebSocketApp( + self.ws = Websocket( self._cached_gateway_url, - on_message=self.log_on_error('Error in on_message:', self.on_message), - on_error=self.log_on_error('Error in on_error:', self.on_error), - on_open=self.log_on_error('Error in on_open:', self.on_open), - on_close=self.log_on_error('Error in on_close:', self.on_close), + on_message=self.on_message, + on_error=self.on_error, + on_open=self.on_open, + on_close=self.on_close, ) - self.ws._get_close_args = websocket_get_close_args_override self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_message(self, ws, msg): @@ -108,29 +109,17 @@ class GatewayClient(LoggingClass): msg = zlib.decompress(msg, 15, TEN_MEGABYTES) try: - data = json.loads(msg) + data = loads(msg) except: - self.log.exception('Failed to load dispatch:') + self.log.exception('Failed to parse gateway message: ') return # Update sequence if data['s'] and data['s'] > self.seq: self.seq = data['s'] - if data['op'] == OPCode.DISPATCH: - self.handle_dispatch(data) - elif data['op'] == OPCode.HEARTBEAT: - self.handle_heartbeat(data) - elif data['op'] == OPCode.RECONNECT: - self.handle_reconnect(data) - elif data['op'] == OPCode.INVALID_SESSION: - self.handle_invalid_session(data) - elif data['op'] == OPCode.HELLO: - self.handle_hello(data) - elif data['op'] == OPCode.HEARTBEAT_ACK: - self.handle_heartbeat_ack(data) - else: - raise Exception('Unknown packet: {}'.format(data['op'])) + # Emit packet + self.packets.emit(OPCode[data['op']], data) def on_error(self, ws, error): if isinstance(error, KeyboardInterrupt): @@ -140,14 +129,25 @@ class GatewayClient(LoggingClass): def on_open(self, ws): if self.seq and self.session_id: self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) - self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) + self.send(OPCode.RESUME, { + 'token': self.client.token, + 'session_id': self.session_id, + 'seq': self.seq + }) else: self.log.info('WS Opened: sending identify payload') - self.send(IdentifyPacket( - token=self.client.token, - compress=True, - large_threshold=250, - shard=[self.client.sharding['number'], self.client.sharding['total']])) + self.send(OPCode.IDENTIFY, { + 'token': self.client.token, + 'compress': True, + 'large_threshold': 250, + 'shard': [self.client.sharding['number'], self.client.sharding['total']], + 'properties': { + '$os': 'linux', + '$browser': 'disco', + '$device': 'disco', + '$referrer': '', + } + }) def on_close(self, ws, code, reason): if self.shutting_down: diff --git a/disco/gateway/packets.py b/disco/gateway/packets.py index 8a8e818..5ca9793 100644 --- a/disco/gateway/packets.py +++ b/disco/gateway/packets.py @@ -1,7 +1,5 @@ from holster.enum import Enum -from disco.util.oop import TypedClass - OPCode = Enum( DISPATCH=0, HEARTBEAT=1, @@ -17,86 +15,3 @@ OPCode = Enum( HEARTBEAT_ACK=11, GUILD_SYNC=12, ) - - -class Packet(TypedClass): - pass - - -class DispatchPacket(Packet): - OP = OPCode.DISPATCH - - PARAMS = { - ('d', 'data'): {}, - ('t', 'event'): str, - } - - -class HeartbeatPacket(Packet): - OP = OPCode.HEARTBEAT - - PARAMS = { - ('d', 'data'): (int, ), - } - - -class IdentifyPacket(Packet): - OP = OPCode.IDENTIFY - - PARAMS = { - 'token': str, - 'compress': bool, - 'large_threshold': int, - 'shard': [int], - 'properties': 'properties' - } - - @property - def properties(self): - return { - '$os': 'linux', - '$browser': 'disco', - '$device': 'disco', - '$referrer': '', - } - - -class ResumePacket(Packet): - OP = OPCode.RESUME - - PARAMS = { - 'token': str, - 'session_id': str, - 'seq': int, - } - - -class ReconnectPacket(Packet): - OP = OPCode.RECONNECT - - -class InvalidSessionPacket(Packet): - OP = OPCode.INVALID_SESSION - - -class HelloPacket(Packet): - OP = OPCode.HELLO - - PARAMS = { - 'heartbeat_interval': int, - '_trace': [str], - } - - -class HeartbeatAckPacket(Packet): - OP = OPCode.HEARTBEAT_ACK - - -PACKETS = { - int(OPCode.DISPATCH): DispatchPacket, - int(OPCode.HEARTBEAT): HeartbeatPacket, - int(OPCode.RECONNECT): ReconnectPacket, - int(OPCode.INVALID_SESSION): InvalidSessionPacket, - int(OPCode.HELLO): HelloPacket, - int(OPCode.HEARTBEAT_ACK): HeartbeatAckPacket, -} diff --git a/disco/state.py b/disco/state.py index 750853f..8326edf 100644 --- a/disco/state.py +++ b/disco/state.py @@ -23,10 +23,11 @@ class State(object): self.dms = {} self.guilds = {} self.channels = WeakValueDictionary() + self.users = WeakValueDictionary() self.client.events.on('Ready', self.on_ready) - self.messages_stack = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) + self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) if self.config.track_messages: self.client.events.on('MessageCreate', self.on_message_create) self.client.events.on('MessageDelete', self.on_message_delete) @@ -45,15 +46,15 @@ class State(object): self.me = event.user def on_message_create(self, event): - self.messages_stack[event.message.channel_id].append( + self.messages[event.message.channel_id].append( StackMessage(event.message.id, event.message.channel_id, event.message.author.id)) def on_message_update(self, event): message, cid = event.message, event.message.channel_id - if cid not in self.messages_stack: + if cid not in self.messages: return - sm = next((i for i in self.messages_stack[cid] if i.id == message.id), None) + sm = next((i for i in self.messages[cid] if i.id == message.id), None) if not sm: return @@ -62,21 +63,21 @@ class State(object): sm.author_id = message.author.id def on_message_delete(self, event): - if event.channel_id not in self.messages_stack: + if event.channel_id not in self.messages: return - sm = next((i for i in self.messages_stack[event.channel_id] if i.id == event.id), None) + sm = next((i for i in self.messages[event.channel_id] if i.id == event.id), None) if not sm: return - self.messages_stack[event.channel_id].remove(sm) + self.messages[event.channel_id].remove(sm) def on_guild_create(self, event): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) def on_guild_update(self, event): - self.guilds[event.guild.id] = event.guild + self.guilds[event.guild.id].update(event.guild) def on_guild_delete(self, event): if event.guild_id in self.guilds: @@ -92,12 +93,8 @@ class State(object): self.channels[event.channel.id] = event.channel def on_channel_update(self, event): - if event.channel.is_guild and event.channel.guild_id in self.guilds: - self.guilds[event.channel.id] = event.channel - self.channels[event.channel.id] = event.channel - elif event.channel.is_dm: - self.dms[event.channel.id] = event.channel - self.channels[event.channel.id] = event.channel + if event.channel.id in self.channels: + self.channels[event.channel.id].update(event.channel) def on_channel_delete(self, event): if event.channel.is_guild and event.channel.guild_id in self.guilds: diff --git a/disco/types/base.py b/disco/types/base.py index c6fb366..6e945eb 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -5,6 +5,12 @@ from disco.util import skema_find_recursive_by_type class BaseType(skema.Model): + def on_create(self): + pass + + def update(self, other): + self.__dict__.update(other.__dict__) + @classmethod def create(cls, client, data): obj = cls(data) @@ -16,6 +22,7 @@ class BaseType(skema.Model): item.client = client obj.client = client + obj.on_create() return obj @classmethod diff --git a/disco/types/channel.py b/disco/types/channel.py index 2e05e3e..15b1d0f 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -6,6 +6,7 @@ from disco.util.cache import cached_property from disco.util.types import ListToDictType from disco.types.base import BaseType from disco.types.user import User +from disco.voice.client import VoiceClient ChannelType = Enum( @@ -52,11 +53,15 @@ class Channel(BaseType): def is_dm(self): return self.type in (ChannelType.DM, ChannelType.GROUP_DM) + @property + def is_voice(self): + return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM) + @property def last_message_id(self): - if self.id not in self.client.state.messages_stack: + if self.id not in self.client.state.messages: return self._last_message_id - return self.client.state.messages_stack[self.id][-1].id + return self.client.state.messages[self.id][-1].id @property def messages(self): @@ -78,6 +83,11 @@ class Channel(BaseType): def send_message(self, content, nonce=None, tts=False): return self.client.api.channels_messages_create(self.id, content, nonce, tts) + def connect(self, *args, **kwargs): + vc = VoiceClient(self) + vc.connect(*args, **kwargs) + return vc + class MessageIterator(object): Direction = Enum('UP', 'DOWN') diff --git a/disco/types/guild.py b/disco/types/guild.py index 09c6c55..bbb4570 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,5 +1,7 @@ import skema +from disco.api.http import APIException +from disco.util import to_snowflake from disco.types.base import BaseType from disco.util.types import PreHookType, ListToDictType from disco.types.user import User @@ -32,6 +34,15 @@ class GuildMember(BaseType): joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) roles = skema.ListType(skema.SnowflakeType()) + def get_voice_state(self): + return self.guild.get_voice_state(self) + + def kick(self): + self.client.api.guilds_members_kick(self.guild.id, self.user.id) + + def ban(self, delete_message_days=0): + self.client.api.guilds_bans_create(self.guild.id, self.user.id, delete_message_days) + @property def id(self): return self.user.id @@ -60,12 +71,33 @@ class Guild(BaseType): channels = ListToDictType('id', skema.ModelType(Channel)) roles = ListToDictType('id', skema.ModelType(Role)) emojis = ListToDictType('id', skema.ModelType(Emoji)) - voice_states = ListToDictType('id', skema.ModelType(VoiceState)) + voice_states = ListToDictType('session_id', skema.ModelType(VoiceState)) + + def get_voice_state(self, user): + user = to_snowflake(user) + + for state in self.voice_states.values(): + if state.user_id == user: + return state def get_member(self, user): - return self.members.get(user.id) + user = to_snowflake(user) + + if user not in self.members: + try: + self.members[user] = self.client.api.guilds_members_get(self.id, user) + except APIException: + pass + + return self.members.get(user) + + def validate_members(self, ctx): + if self.members: + for member in self.members.values(): + member.guild = self def validate_channels(self, ctx): if self.channels: for channel in self.channels.values(): channel.guild_id = self.id + channel.guild = self diff --git a/disco/types/user.py b/disco/types/user.py index 6b24c9e..86bff12 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -12,3 +12,6 @@ class User(BaseType): verified = skema.BooleanType(required=False) email = skema.EmailType(required=False) + + def on_create(self): + self.client.state.users[self.id] = self diff --git a/disco/types/voice.py b/disco/types/voice.py index 997a77d..843f04b 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -4,4 +4,26 @@ from disco.types.base import BaseType class VoiceState(BaseType): - id = skema.SnowflakeType() + session_id = skema.StringType() + + guild_id = skema.SnowflakeType() + channel_id = skema.SnowflakeType() + user_id = skema.SnowflakeType() + + deaf = skema.BooleanType() + mute = skema.BooleanType() + self_deaf = skema.BooleanType() + self_mute = skema.BooleanType() + suppress = skema.BooleanType() + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def user(self): + return self.client.state.users.get(self.user_id) diff --git a/disco/util/__init__.py b/disco/util/__init__.py index 1b5f31e..d0c3a37 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -1,6 +1,17 @@ import skema +def to_snowflake(i): + if isinstance(i, long): + return i + elif isinstance(i, str): + return long(i) + elif hasattr(i, 'id'): + return i.id + + raise Exception('{} ({}) is not convertable to a snowflake'.format(type(i), i)) + + def _recurse(typ, field, value): result = [] diff --git a/disco/util/cache.py b/disco/util/cache.py index 5a86936..614858e 100644 --- a/disco/util/cache.py +++ b/disco/util/cache.py @@ -1,8 +1,15 @@ def cached_property(f): - def deco(self, *args, **kwargs): + def getf(self, *args, **kwargs): if not hasattr(self, '__' + f.__name__): setattr(self, '__' + f.__name__, f(self, *args, **kwargs)) return getattr(self, '__' + f.__name__) - return property(deco) + + def setf(self, value): + setattr(self, '__' + f.__name__, value) + + def delf(self): + setattr(self, '__' + f.__name__, None) + + return property(getf, setf, delf) diff --git a/disco/util/json.py b/disco/util/json.py new file mode 100644 index 0000000..01267ff --- /dev/null +++ b/disco/util/json.py @@ -0,0 +1,11 @@ +from __future__ import absolute_import + +from json import dumps + +try: + from rapidjson import loads +except ImportError: + print '[WARNING] rapidjson not installed, falling back to default Python JSON parser' + from json import loads + +__all__ = ['dumps', 'loads'] diff --git a/disco/util/oop.py b/disco/util/oop.py deleted file mode 100644 index 14b3509..0000000 --- a/disco/util/oop.py +++ /dev/null @@ -1,139 +0,0 @@ -import inspect - - -class TypedClassException(Exception): - pass - - -def construct_typed_class(cls, data): - obj = cls() - load_typed_class(obj, data) - return obj - - -def get_field_and_alias(field): - if isinstance(field, tuple): - return field - else: - return field, field - - -def get_optional(typ): - if isinstance(typ, tuple) and len(typ) == 1: - return True, typ[0] - return False, typ - - -def cast(typ, value): - valid = True - - # TODO: better exceptions - if isinstance(typ, list): - if typ: - typ = typ[0] - value = map(typ, value) - else: - list(value) - elif isinstance(typ, dict): - if typ: - ktyp, vtyp = typ.items()[0] - value = {ktyp(k): vtyp(v) for k, v in typ.items()} - else: - dict(value) - elif isinstance(typ, set): - if typ: - typ = list(typ)[0] - value = set(map(typ, value)) - else: - set(value) - elif isinstance(typ, str): - valid = False - elif not isinstance(value, typ): - value = typ(value) - - return valid, value - - -def load_typed_class(obj, params, data): - print obj, params, data - for field, typ in params.items(): - field, alias = get_field_and_alias(field) - - # Skipped field - if typ is None: - continue - - optional, typ = get_optional(typ) - if field not in data and not optional: - raise TypedClassException('Missing value for attribute `{}`'.format(field)) - - value = data[field] - - print field, alias, value, typ - if value is None: - if not optional: - raise TypedClassException('Non-optional attribute `{}` cannot take None'.format(field)) - else: - valid, value = cast(typ, value) - if not valid: - continue - - setattr(obj, alias, value) - - -def dump_typed_class(obj, params): - data = {} - - for field, typ in params.items(): - field, alias = get_field_and_alias(field) - - value = getattr(obj, alias, None) - - if typ is None: - data[field] = typ - continue - - optional, typ = get_optional(typ) - if not value and not optional: - raise TypedClassException('Missing value for attribute `{}`'.format(field)) - - _, value = cast(typ, value) - data[field] = value - - return data - - -def get_params(obj): - assert(issubclass(obj.__class__, TypedClass)) - - if not hasattr(obj.__class__, '_cached_oop_params'): - base = {} - for cls in reversed(inspect.getmro(obj.__class__)): - base.update(getattr(cls, 'PARAMS', {})) - obj.__class__._cached_oop_params = base - return obj.__class__._cached_oop_params - - -def load_typed_instance(obj, data): - return load_typed_class(obj, get_params(obj), data) - - -class TypedClass(object): - def __init__(self, **kwargs): - # TODO: validate - self.__dict__.update(kwargs) - - @classmethod - def from_dict(cls, data): - self = cls() - load_typed_instance(self, data) - return self - - def to_dict(self): - return dump_typed_class(self, get_params(self)) - - -def require_implementation(attr): - def _f(self, *args, **kwargs): - raise NotImplementedError('{} must implement method {}', self.__class__.__name, attr) - return _f diff --git a/disco/util/websocket.py b/disco/util/websocket.py new file mode 100644 index 0000000..9c6f2ab --- /dev/null +++ b/disco/util/websocket.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import + +import websocket +import gevent +import six + +from disco.util.logging import LoggingClass + + +class Websocket(LoggingClass, websocket.WebSocketApp): + def __init__(self, *args, **kwargs): + LoggingClass.__init__(self) + websocket.WebSocketApp.__init__(self, *args, **kwargs) + def _get_close_args(self, data): + if data and len(data) >= 2: + code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) + reason = data[2:].decode('utf-8') + return [code, reason] + return [None, None] + + def _callback(self, callback, *args): + if not callback: + return + + try: + gevent.spawn(callback, self, *args) + except Exception as e: + self.log.exception('Error in Websocket callback for {}: '.format(callback)) diff --git a/disco/voice/__init__.py b/disco/voice/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/disco/voice/client.py b/disco/voice/client.py new file mode 100644 index 0000000..c45c66e --- /dev/null +++ b/disco/voice/client.py @@ -0,0 +1,119 @@ +import gevent + +from holster.enum import Enum +from holster.emitter import Emitter + +from disco.util.websocket import Websocket +from disco.util.logging import LoggingClass +from disco.util.json import loads, dumps +from disco.voice.packets import VoiceOPCode +from disco.gateway.packets import OPCode + +VoiceState = Enum( + DISCONNECTED=0, + AWAITING_ENDPOINT=1, + AUTHENTICATING=2, + CONNECTING=3, + CONNECTED=4, + VOICE_CONNECTING=5, + VOICE_CONNECTED=6, +) + + +class VoiceException(Exception): + def __init__(self, msg, client): + self.voice_client = client + super(VoiceException, self).__init__(msg) + + +class VoiceClient(LoggingClass): + def __init__(self, channel): + assert(channel.is_voice) + self.channel = channel + self.client = self.channel.client + + self.packets = Emitter(gevent.spawn) + self.packets.on(VoiceOPCode.READY, self.on_voice_ready) + self.packets.on(VoiceOPCode.SESSION_DESCRIPTION, self.on_voice_sdp) + + # State + self.state = VoiceState.DISCONNECTED + self.connected = gevent.event.Event() + self.token = None + self.endpoint = None + + self.update_listener = None + + # Websocket connection + self.ws = None + + def send(self, op, data): + self.ws.send(dumps({ + 'op': op.value, + 'd': data, + })) + + def on_voice_ready(self, data): + print data + + def on_voice_sdp(self, data): + print data + + def on_voice_server_update(self, data): + if self.channel.guild_id != data.guild_id or not data.token: + return + + if self.token and self.token != data.token: + return + + self.token = data.token + self.state = VoiceState.AUTHENTICATING + + self.endpoint = 'wss://{}'.format(data.endpoint.split(':', 1)[0]) + self.ws = Websocket( + self.endpoint, + on_message=self.on_message, + on_error=self.on_error, + on_open=self.on_open, + on_close=self.on_close, + ) + self.ws.run_forever() + + def on_message(self, ws, msg): + try: + data = loads(msg) + except: + self.log.exception('Failed to parse voice gateway message: ') + + self.packets.emit(VoiceOPCode[data['op']], data) + + def on_error(self, ws, err): + # TODO + self.log.warning('Voice websocket error: {}'.format(err)) + + def on_open(self, ws): + self.send(VoiceOPCode.IDENTIFY, { + 'server_id': self.channel.guild_id, + 'user_id': self.client.state.me.id, + 'session_id': self.client.gw.session_id, + 'token': self.token + }) + + def on_close(self, ws): + # TODO + self.log.warning('Voice websocket disconnected') + + def connect(self, timeout=5, mute=False, deaf=False): + self.state = VoiceState.AWAITING_ENDPOINT + + self.update_listener = self.client.events.on('VoiceServerUpdate', self.on_voice_server_update) + + self.client.gw.send(OPCode.VOICE_STATE_UPDATE, { + 'self_mute': mute, + 'self_deaf': deaf, + 'guild_id': int(self.channel.guild_id), + 'channel_id': int(self.channel.id), + }) + + if not self.connected.wait(timeout) or self.state != VoiceState.CONNECTED: + raise VoiceException('Failed to connect to voice', self) diff --git a/disco/voice/packets.py b/disco/voice/packets.py new file mode 100644 index 0000000..95d6118 --- /dev/null +++ b/disco/voice/packets.py @@ -0,0 +1,10 @@ +from holster.enum import Enum + +VoiceOPCode = Enum( + IDENTIFY=0, + SELECT_PROTOCOL=1, + READY=2, + HEARTBEAT=3, + SESSION_DESCRIPTION=4, + SPEAKING=5, +) diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index c70e7c9..4bfd16f 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -49,6 +49,17 @@ class BasicPlugin(Plugin): '\n'.join([str(i.id) for i in self.state.messages[event.channel.id]]) )) + @Plugin.command('airhorn') + def on_airhorn(self, event): + vs = event.member.get_voice_state() + if not vs: + event.msg.reply('You are not connected to voice') + return + + print vs.channel + print vs.channel_id + print vs.channel.connect() + if __name__ == '__main__': bot = Bot(disco_main()) bot.add_plugin(BasicPlugin)