Browse Source

Refactoring, start work on voice support

pull/3/head
Andrei 9 years ago
parent
commit
13ee9eae09
  1. 4
      disco/bot/command.py
  2. 1
      disco/client.py
  3. 96
      disco/gateway/client.py
  4. 85
      disco/gateway/packets.py
  5. 25
      disco/state.py
  6. 7
      disco/types/base.py
  7. 14
      disco/types/channel.py
  8. 36
      disco/types/guild.py
  9. 3
      disco/types/user.py
  10. 24
      disco/types/voice.py
  11. 11
      disco/util/__init__.py
  12. 11
      disco/util/cache.py
  13. 11
      disco/util/json.py
  14. 139
      disco/util/oop.py
  15. 28
      disco/util/websocket.py
  16. 0
      disco/voice/__init__.py
  17. 119
      disco/voice/client.py
  18. 10
      disco/voice/packets.py
  19. 11
      examples/basic_plugin.py

4
disco/bot/command.py

@ -15,6 +15,10 @@ class CommandEvent(object):
self.name = self.match.group(1) self.name = self.match.group(1)
self.args = self.match.group(2).strip().split(' ') self.args = self.match.group(2).strip().split(' ')
@cached_property
def member(self):
return self.guild.get_member(self.actor)
@property @property
def channel(self): def channel(self):
return self.msg.channel return self.msg.channel

1
disco/client.py

@ -17,6 +17,7 @@ class DiscoClient(object):
self.sharding = sharding or {'number': 0, 'total': 1} self.sharding = sharding or {'number': 0, 'total': 1}
self.events = Emitter(gevent.spawn) self.events = Emitter(gevent.spawn)
self.packets = Emitter(gevent.spawn)
self.state = State(self) self.state = State(self)
self.api = APIClient(self) self.api = APIClient(self)

96
disco/gateway/client.py

@ -1,27 +1,17 @@
import websocket
import gevent import gevent
import json
import zlib import zlib
import six
import ssl import ssl
from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket from disco.gateway.packets import OPCode
from disco.gateway.events import GatewayEvent from disco.gateway.events import GatewayEvent
from disco.util.json import loads, dumps
from disco.util.websocket import Websocket
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
GATEWAY_VERSION = 6 GATEWAY_VERSION = 6
TEN_MEGABYTES = 10490000 TEN_MEGABYTES = 10490000
# Hack to get websocket close information
def websocket_get_close_args_override(data):
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]
return [None, None]
class GatewayClient(LoggingClass): class GatewayClient(LoggingClass):
MAX_RECONNECTS = 5 MAX_RECONNECTS = 5
@ -29,7 +19,19 @@ class GatewayClient(LoggingClass):
super(GatewayClient, self).__init__() super(GatewayClient, self).__init__()
self.client = client self.client = client
self.client.events.on('Ready', self.on_ready) self.events = client.events
self.packets = client.packets
# Create emitter and bind to gateway payloads
self.packets.on(OPCode.DISPATCH, self.handle_dispatch)
self.packets.on(OPCode.HEARTBEAT, self.handle_heartbeat)
self.packets.on(OPCode.RECONNECT, self.handle_reconnect)
self.packets.on(OPCode.INVALID_SESSION, self.handle_invalid_session)
self.packets.on(OPCode.HELLO, self.handle_hello)
self.packets.on(OPCode.HEARTBEAT_ACK, self.handle_heartbeat_ack)
# Bind to ready payload
self.events.on('Ready', self.on_ready)
# Websocket connection # Websocket connection
self.ws = None self.ws = None
@ -46,15 +48,15 @@ class GatewayClient(LoggingClass):
# Heartbeat # Heartbeat
self._heartbeat_task = None self._heartbeat_task = None
def send(self, packet): def send(self, op, data):
self.ws.send(json.dumps({ self.ws.send(dumps({
'op': int(packet.OP), 'op': op.value,
'd': packet.to_dict(), 'd': data,
})) }))
def heartbeat_task(self, interval): def heartbeat_task(self, interval):
while True: while True:
self.send(HeartbeatPacket(data=self.seq)) self.send(OPCode.HEARTBEAT, self.seq)
gevent.sleep(interval / 1000) gevent.sleep(interval / 1000)
def handle_dispatch(self, packet): def handle_dispatch(self, packet):
@ -63,7 +65,7 @@ class GatewayClient(LoggingClass):
self.client.events.emit(obj.__class__.__name__, obj) self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet): def handle_heartbeat(self, packet):
self.send(HeartbeatPacket(data=self.seq)) self.send(OPCode.HEARTBEAT, self.seq)
def handle_reconnect(self, packet): def handle_reconnect(self, packet):
self.log.warning('Received RECONNECT request, forcing a fresh reconnect') self.log.warning('Received RECONNECT request, forcing a fresh reconnect')
@ -92,14 +94,13 @@ class GatewayClient(LoggingClass):
self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json')
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url)
self.ws = websocket.WebSocketApp( self.ws = Websocket(
self._cached_gateway_url, self._cached_gateway_url,
on_message=self.log_on_error('Error in on_message:', self.on_message), on_message=self.on_message,
on_error=self.log_on_error('Error in on_error:', self.on_error), on_error=self.on_error,
on_open=self.log_on_error('Error in on_open:', self.on_open), on_open=self.on_open,
on_close=self.log_on_error('Error in on_close:', self.on_close), on_close=self.on_close,
) )
self.ws._get_close_args = websocket_get_close_args_override
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_message(self, ws, msg): def on_message(self, ws, msg):
@ -108,29 +109,17 @@ class GatewayClient(LoggingClass):
msg = zlib.decompress(msg, 15, TEN_MEGABYTES) msg = zlib.decompress(msg, 15, TEN_MEGABYTES)
try: try:
data = json.loads(msg) data = loads(msg)
except: except:
self.log.exception('Failed to load dispatch:') self.log.exception('Failed to parse gateway message: ')
return return
# Update sequence # Update sequence
if data['s'] and data['s'] > self.seq: if data['s'] and data['s'] > self.seq:
self.seq = data['s'] self.seq = data['s']
if data['op'] == OPCode.DISPATCH: # Emit packet
self.handle_dispatch(data) self.packets.emit(OPCode[data['op']], data)
elif data['op'] == OPCode.HEARTBEAT:
self.handle_heartbeat(data)
elif data['op'] == OPCode.RECONNECT:
self.handle_reconnect(data)
elif data['op'] == OPCode.INVALID_SESSION:
self.handle_invalid_session(data)
elif data['op'] == OPCode.HELLO:
self.handle_hello(data)
elif data['op'] == OPCode.HEARTBEAT_ACK:
self.handle_heartbeat_ack(data)
else:
raise Exception('Unknown packet: {}'.format(data['op']))
def on_error(self, ws, error): def on_error(self, ws, error):
if isinstance(error, KeyboardInterrupt): if isinstance(error, KeyboardInterrupt):
@ -140,14 +129,25 @@ class GatewayClient(LoggingClass):
def on_open(self, ws): def on_open(self, ws):
if self.seq and self.session_id: if self.seq and self.session_id:
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) self.send(OPCode.RESUME, {
'token': self.client.token,
'session_id': self.session_id,
'seq': self.seq
})
else: else:
self.log.info('WS Opened: sending identify payload') self.log.info('WS Opened: sending identify payload')
self.send(IdentifyPacket( self.send(OPCode.IDENTIFY, {
token=self.client.token, 'token': self.client.token,
compress=True, 'compress': True,
large_threshold=250, 'large_threshold': 250,
shard=[self.client.sharding['number'], self.client.sharding['total']])) 'shard': [self.client.sharding['number'], self.client.sharding['total']],
'properties': {
'$os': 'linux',
'$browser': 'disco',
'$device': 'disco',
'$referrer': '',
}
})
def on_close(self, ws, code, reason): def on_close(self, ws, code, reason):
if self.shutting_down: if self.shutting_down:

85
disco/gateway/packets.py

@ -1,7 +1,5 @@
from holster.enum import Enum from holster.enum import Enum
from disco.util.oop import TypedClass
OPCode = Enum( OPCode = Enum(
DISPATCH=0, DISPATCH=0,
HEARTBEAT=1, HEARTBEAT=1,
@ -17,86 +15,3 @@ OPCode = Enum(
HEARTBEAT_ACK=11, HEARTBEAT_ACK=11,
GUILD_SYNC=12, GUILD_SYNC=12,
) )
class Packet(TypedClass):
pass
class DispatchPacket(Packet):
OP = OPCode.DISPATCH
PARAMS = {
('d', 'data'): {},
('t', 'event'): str,
}
class HeartbeatPacket(Packet):
OP = OPCode.HEARTBEAT
PARAMS = {
('d', 'data'): (int, ),
}
class IdentifyPacket(Packet):
OP = OPCode.IDENTIFY
PARAMS = {
'token': str,
'compress': bool,
'large_threshold': int,
'shard': [int],
'properties': 'properties'
}
@property
def properties(self):
return {
'$os': 'linux',
'$browser': 'disco',
'$device': 'disco',
'$referrer': '',
}
class ResumePacket(Packet):
OP = OPCode.RESUME
PARAMS = {
'token': str,
'session_id': str,
'seq': int,
}
class ReconnectPacket(Packet):
OP = OPCode.RECONNECT
class InvalidSessionPacket(Packet):
OP = OPCode.INVALID_SESSION
class HelloPacket(Packet):
OP = OPCode.HELLO
PARAMS = {
'heartbeat_interval': int,
'_trace': [str],
}
class HeartbeatAckPacket(Packet):
OP = OPCode.HEARTBEAT_ACK
PACKETS = {
int(OPCode.DISPATCH): DispatchPacket,
int(OPCode.HEARTBEAT): HeartbeatPacket,
int(OPCode.RECONNECT): ReconnectPacket,
int(OPCode.INVALID_SESSION): InvalidSessionPacket,
int(OPCode.HELLO): HelloPacket,
int(OPCode.HEARTBEAT_ACK): HeartbeatAckPacket,
}

25
disco/state.py

@ -23,10 +23,11 @@ class State(object):
self.dms = {} self.dms = {}
self.guilds = {} self.guilds = {}
self.channels = WeakValueDictionary() self.channels = WeakValueDictionary()
self.users = WeakValueDictionary()
self.client.events.on('Ready', self.on_ready) self.client.events.on('Ready', self.on_ready)
self.messages_stack = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size))
if self.config.track_messages: if self.config.track_messages:
self.client.events.on('MessageCreate', self.on_message_create) self.client.events.on('MessageCreate', self.on_message_create)
self.client.events.on('MessageDelete', self.on_message_delete) self.client.events.on('MessageDelete', self.on_message_delete)
@ -45,15 +46,15 @@ class State(object):
self.me = event.user self.me = event.user
def on_message_create(self, event): def on_message_create(self, event):
self.messages_stack[event.message.channel_id].append( self.messages[event.message.channel_id].append(
StackMessage(event.message.id, event.message.channel_id, event.message.author.id)) StackMessage(event.message.id, event.message.channel_id, event.message.author.id))
def on_message_update(self, event): def on_message_update(self, event):
message, cid = event.message, event.message.channel_id message, cid = event.message, event.message.channel_id
if cid not in self.messages_stack: if cid not in self.messages:
return return
sm = next((i for i in self.messages_stack[cid] if i.id == message.id), None) sm = next((i for i in self.messages[cid] if i.id == message.id), None)
if not sm: if not sm:
return return
@ -62,21 +63,21 @@ class State(object):
sm.author_id = message.author.id sm.author_id = message.author.id
def on_message_delete(self, event): def on_message_delete(self, event):
if event.channel_id not in self.messages_stack: if event.channel_id not in self.messages:
return return
sm = next((i for i in self.messages_stack[event.channel_id] if i.id == event.id), None) sm = next((i for i in self.messages[event.channel_id] if i.id == event.id), None)
if not sm: if not sm:
return return
self.messages_stack[event.channel_id].remove(sm) self.messages[event.channel_id].remove(sm)
def on_guild_create(self, event): def on_guild_create(self, event):
self.guilds[event.guild.id] = event.guild self.guilds[event.guild.id] = event.guild
self.channels.update(event.guild.channels) self.channels.update(event.guild.channels)
def on_guild_update(self, event): def on_guild_update(self, event):
self.guilds[event.guild.id] = event.guild self.guilds[event.guild.id].update(event.guild)
def on_guild_delete(self, event): def on_guild_delete(self, event):
if event.guild_id in self.guilds: if event.guild_id in self.guilds:
@ -92,12 +93,8 @@ class State(object):
self.channels[event.channel.id] = event.channel self.channels[event.channel.id] = event.channel
def on_channel_update(self, event): def on_channel_update(self, event):
if event.channel.is_guild and event.channel.guild_id in self.guilds: if event.channel.id in self.channels:
self.guilds[event.channel.id] = event.channel self.channels[event.channel.id].update(event.channel)
self.channels[event.channel.id] = event.channel
elif event.channel.is_dm:
self.dms[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
def on_channel_delete(self, event): def on_channel_delete(self, event):
if event.channel.is_guild and event.channel.guild_id in self.guilds: if event.channel.is_guild and event.channel.guild_id in self.guilds:

7
disco/types/base.py

@ -5,6 +5,12 @@ from disco.util import skema_find_recursive_by_type
class BaseType(skema.Model): class BaseType(skema.Model):
def on_create(self):
pass
def update(self, other):
self.__dict__.update(other.__dict__)
@classmethod @classmethod
def create(cls, client, data): def create(cls, client, data):
obj = cls(data) obj = cls(data)
@ -16,6 +22,7 @@ class BaseType(skema.Model):
item.client = client item.client = client
obj.client = client obj.client = client
obj.on_create()
return obj return obj
@classmethod @classmethod

14
disco/types/channel.py

@ -6,6 +6,7 @@ from disco.util.cache import cached_property
from disco.util.types import ListToDictType from disco.util.types import ListToDictType
from disco.types.base import BaseType from disco.types.base import BaseType
from disco.types.user import User from disco.types.user import User
from disco.voice.client import VoiceClient
ChannelType = Enum( ChannelType = Enum(
@ -52,11 +53,15 @@ class Channel(BaseType):
def is_dm(self): def is_dm(self):
return self.type in (ChannelType.DM, ChannelType.GROUP_DM) return self.type in (ChannelType.DM, ChannelType.GROUP_DM)
@property
def is_voice(self):
return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM)
@property @property
def last_message_id(self): def last_message_id(self):
if self.id not in self.client.state.messages_stack: if self.id not in self.client.state.messages:
return self._last_message_id return self._last_message_id
return self.client.state.messages_stack[self.id][-1].id return self.client.state.messages[self.id][-1].id
@property @property
def messages(self): def messages(self):
@ -78,6 +83,11 @@ class Channel(BaseType):
def send_message(self, content, nonce=None, tts=False): def send_message(self, content, nonce=None, tts=False):
return self.client.api.channels_messages_create(self.id, content, nonce, tts) return self.client.api.channels_messages_create(self.id, content, nonce, tts)
def connect(self, *args, **kwargs):
vc = VoiceClient(self)
vc.connect(*args, **kwargs)
return vc
class MessageIterator(object): class MessageIterator(object):
Direction = Enum('UP', 'DOWN') Direction = Enum('UP', 'DOWN')

36
disco/types/guild.py

@ -1,5 +1,7 @@
import skema import skema
from disco.api.http import APIException
from disco.util import to_snowflake
from disco.types.base import BaseType from disco.types.base import BaseType
from disco.util.types import PreHookType, ListToDictType from disco.util.types import PreHookType, ListToDictType
from disco.types.user import User from disco.types.user import User
@ -32,6 +34,15 @@ class GuildMember(BaseType):
joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())
roles = skema.ListType(skema.SnowflakeType()) roles = skema.ListType(skema.SnowflakeType())
def get_voice_state(self):
return self.guild.get_voice_state(self)
def kick(self):
self.client.api.guilds_members_kick(self.guild.id, self.user.id)
def ban(self, delete_message_days=0):
self.client.api.guilds_bans_create(self.guild.id, self.user.id, delete_message_days)
@property @property
def id(self): def id(self):
return self.user.id return self.user.id
@ -60,12 +71,33 @@ class Guild(BaseType):
channels = ListToDictType('id', skema.ModelType(Channel)) channels = ListToDictType('id', skema.ModelType(Channel))
roles = ListToDictType('id', skema.ModelType(Role)) roles = ListToDictType('id', skema.ModelType(Role))
emojis = ListToDictType('id', skema.ModelType(Emoji)) emojis = ListToDictType('id', skema.ModelType(Emoji))
voice_states = ListToDictType('id', skema.ModelType(VoiceState)) voice_states = ListToDictType('session_id', skema.ModelType(VoiceState))
def get_voice_state(self, user):
user = to_snowflake(user)
for state in self.voice_states.values():
if state.user_id == user:
return state
def get_member(self, user): def get_member(self, user):
return self.members.get(user.id) user = to_snowflake(user)
if user not in self.members:
try:
self.members[user] = self.client.api.guilds_members_get(self.id, user)
except APIException:
pass
return self.members.get(user)
def validate_members(self, ctx):
if self.members:
for member in self.members.values():
member.guild = self
def validate_channels(self, ctx): def validate_channels(self, ctx):
if self.channels: if self.channels:
for channel in self.channels.values(): for channel in self.channels.values():
channel.guild_id = self.id channel.guild_id = self.id
channel.guild = self

3
disco/types/user.py

@ -12,3 +12,6 @@ class User(BaseType):
verified = skema.BooleanType(required=False) verified = skema.BooleanType(required=False)
email = skema.EmailType(required=False) email = skema.EmailType(required=False)
def on_create(self):
self.client.state.users[self.id] = self

24
disco/types/voice.py

@ -4,4 +4,26 @@ from disco.types.base import BaseType
class VoiceState(BaseType): class VoiceState(BaseType):
id = skema.SnowflakeType() session_id = skema.StringType()
guild_id = skema.SnowflakeType()
channel_id = skema.SnowflakeType()
user_id = skema.SnowflakeType()
deaf = skema.BooleanType()
mute = skema.BooleanType()
self_deaf = skema.BooleanType()
self_mute = skema.BooleanType()
suppress = skema.BooleanType()
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def user(self):
return self.client.state.users.get(self.user_id)

11
disco/util/__init__.py

@ -1,6 +1,17 @@
import skema import skema
def to_snowflake(i):
if isinstance(i, long):
return i
elif isinstance(i, str):
return long(i)
elif hasattr(i, 'id'):
return i.id
raise Exception('{} ({}) is not convertable to a snowflake'.format(type(i), i))
def _recurse(typ, field, value): def _recurse(typ, field, value):
result = [] result = []

11
disco/util/cache.py

@ -1,8 +1,15 @@
def cached_property(f): def cached_property(f):
def deco(self, *args, **kwargs): def getf(self, *args, **kwargs):
if not hasattr(self, '__' + f.__name__): if not hasattr(self, '__' + f.__name__):
setattr(self, '__' + f.__name__, f(self, *args, **kwargs)) setattr(self, '__' + f.__name__, f(self, *args, **kwargs))
return getattr(self, '__' + f.__name__) return getattr(self, '__' + f.__name__)
return property(deco)
def setf(self, value):
setattr(self, '__' + f.__name__, value)
def delf(self):
setattr(self, '__' + f.__name__, None)
return property(getf, setf, delf)

11
disco/util/json.py

@ -0,0 +1,11 @@
from __future__ import absolute_import
from json import dumps
try:
from rapidjson import loads
except ImportError:
print '[WARNING] rapidjson not installed, falling back to default Python JSON parser'
from json import loads
__all__ = ['dumps', 'loads']

139
disco/util/oop.py

@ -1,139 +0,0 @@
import inspect
class TypedClassException(Exception):
pass
def construct_typed_class(cls, data):
obj = cls()
load_typed_class(obj, data)
return obj
def get_field_and_alias(field):
if isinstance(field, tuple):
return field
else:
return field, field
def get_optional(typ):
if isinstance(typ, tuple) and len(typ) == 1:
return True, typ[0]
return False, typ
def cast(typ, value):
valid = True
# TODO: better exceptions
if isinstance(typ, list):
if typ:
typ = typ[0]
value = map(typ, value)
else:
list(value)
elif isinstance(typ, dict):
if typ:
ktyp, vtyp = typ.items()[0]
value = {ktyp(k): vtyp(v) for k, v in typ.items()}
else:
dict(value)
elif isinstance(typ, set):
if typ:
typ = list(typ)[0]
value = set(map(typ, value))
else:
set(value)
elif isinstance(typ, str):
valid = False
elif not isinstance(value, typ):
value = typ(value)
return valid, value
def load_typed_class(obj, params, data):
print obj, params, data
for field, typ in params.items():
field, alias = get_field_and_alias(field)
# Skipped field
if typ is None:
continue
optional, typ = get_optional(typ)
if field not in data and not optional:
raise TypedClassException('Missing value for attribute `{}`'.format(field))
value = data[field]
print field, alias, value, typ
if value is None:
if not optional:
raise TypedClassException('Non-optional attribute `{}` cannot take None'.format(field))
else:
valid, value = cast(typ, value)
if not valid:
continue
setattr(obj, alias, value)
def dump_typed_class(obj, params):
data = {}
for field, typ in params.items():
field, alias = get_field_and_alias(field)
value = getattr(obj, alias, None)
if typ is None:
data[field] = typ
continue
optional, typ = get_optional(typ)
if not value and not optional:
raise TypedClassException('Missing value for attribute `{}`'.format(field))
_, value = cast(typ, value)
data[field] = value
return data
def get_params(obj):
assert(issubclass(obj.__class__, TypedClass))
if not hasattr(obj.__class__, '_cached_oop_params'):
base = {}
for cls in reversed(inspect.getmro(obj.__class__)):
base.update(getattr(cls, 'PARAMS', {}))
obj.__class__._cached_oop_params = base
return obj.__class__._cached_oop_params
def load_typed_instance(obj, data):
return load_typed_class(obj, get_params(obj), data)
class TypedClass(object):
def __init__(self, **kwargs):
# TODO: validate
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, data):
self = cls()
load_typed_instance(self, data)
return self
def to_dict(self):
return dump_typed_class(self, get_params(self))
def require_implementation(attr):
def _f(self, *args, **kwargs):
raise NotImplementedError('{} must implement method {}', self.__class__.__name, attr)
return _f

28
disco/util/websocket.py

@ -0,0 +1,28 @@
from __future__ import absolute_import
import websocket
import gevent
import six
from disco.util.logging import LoggingClass
class Websocket(LoggingClass, websocket.WebSocketApp):
def __init__(self, *args, **kwargs):
LoggingClass.__init__(self)
websocket.WebSocketApp.__init__(self, *args, **kwargs)
def _get_close_args(self, data):
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]
return [None, None]
def _callback(self, callback, *args):
if not callback:
return
try:
gevent.spawn(callback, self, *args)
except Exception as e:
self.log.exception('Error in Websocket callback for {}: '.format(callback))

0
disco/voice/__init__.py

119
disco/voice/client.py

@ -0,0 +1,119 @@
import gevent
from holster.enum import Enum
from holster.emitter import Emitter
from disco.util.websocket import Websocket
from disco.util.logging import LoggingClass
from disco.util.json import loads, dumps
from disco.voice.packets import VoiceOPCode
from disco.gateway.packets import OPCode
VoiceState = Enum(
DISCONNECTED=0,
AWAITING_ENDPOINT=1,
AUTHENTICATING=2,
CONNECTING=3,
CONNECTED=4,
VOICE_CONNECTING=5,
VOICE_CONNECTED=6,
)
class VoiceException(Exception):
def __init__(self, msg, client):
self.voice_client = client
super(VoiceException, self).__init__(msg)
class VoiceClient(LoggingClass):
def __init__(self, channel):
assert(channel.is_voice)
self.channel = channel
self.client = self.channel.client
self.packets = Emitter(gevent.spawn)
self.packets.on(VoiceOPCode.READY, self.on_voice_ready)
self.packets.on(VoiceOPCode.SESSION_DESCRIPTION, self.on_voice_sdp)
# State
self.state = VoiceState.DISCONNECTED
self.connected = gevent.event.Event()
self.token = None
self.endpoint = None
self.update_listener = None
# Websocket connection
self.ws = None
def send(self, op, data):
self.ws.send(dumps({
'op': op.value,
'd': data,
}))
def on_voice_ready(self, data):
print data
def on_voice_sdp(self, data):
print data
def on_voice_server_update(self, data):
if self.channel.guild_id != data.guild_id or not data.token:
return
if self.token and self.token != data.token:
return
self.token = data.token
self.state = VoiceState.AUTHENTICATING
self.endpoint = 'wss://{}'.format(data.endpoint.split(':', 1)[0])
self.ws = Websocket(
self.endpoint,
on_message=self.on_message,
on_error=self.on_error,
on_open=self.on_open,
on_close=self.on_close,
)
self.ws.run_forever()
def on_message(self, ws, msg):
try:
data = loads(msg)
except:
self.log.exception('Failed to parse voice gateway message: ')
self.packets.emit(VoiceOPCode[data['op']], data)
def on_error(self, ws, err):
# TODO
self.log.warning('Voice websocket error: {}'.format(err))
def on_open(self, ws):
self.send(VoiceOPCode.IDENTIFY, {
'server_id': self.channel.guild_id,
'user_id': self.client.state.me.id,
'session_id': self.client.gw.session_id,
'token': self.token
})
def on_close(self, ws):
# TODO
self.log.warning('Voice websocket disconnected')
def connect(self, timeout=5, mute=False, deaf=False):
self.state = VoiceState.AWAITING_ENDPOINT
self.update_listener = self.client.events.on('VoiceServerUpdate', self.on_voice_server_update)
self.client.gw.send(OPCode.VOICE_STATE_UPDATE, {
'self_mute': mute,
'self_deaf': deaf,
'guild_id': int(self.channel.guild_id),
'channel_id': int(self.channel.id),
})
if not self.connected.wait(timeout) or self.state != VoiceState.CONNECTED:
raise VoiceException('Failed to connect to voice', self)

10
disco/voice/packets.py

@ -0,0 +1,10 @@
from holster.enum import Enum
VoiceOPCode = Enum(
IDENTIFY=0,
SELECT_PROTOCOL=1,
READY=2,
HEARTBEAT=3,
SESSION_DESCRIPTION=4,
SPEAKING=5,
)

11
examples/basic_plugin.py

@ -49,6 +49,17 @@ class BasicPlugin(Plugin):
'\n'.join([str(i.id) for i in self.state.messages[event.channel.id]]) '\n'.join([str(i.id) for i in self.state.messages[event.channel.id]])
)) ))
@Plugin.command('airhorn')
def on_airhorn(self, event):
vs = event.member.get_voice_state()
if not vs:
event.msg.reply('You are not connected to voice')
return
print vs.channel
print vs.channel_id
print vs.channel.connect()
if __name__ == '__main__': if __name__ == '__main__':
bot = Bot(disco_main()) bot = Bot(disco_main())
bot.add_plugin(BasicPlugin) bot.add_plugin(BasicPlugin)

Loading…
Cancel
Save