Browse Source

Merge branch 'master' into feature/voice

feature/voice
Andrei 8 years ago
parent
commit
027ecbf9e0
  1. 3
      README.md
  2. 2
      disco/__init__.py
  3. 151
      disco/api/client.py
  4. 82
      disco/api/http.py
  5. 21
      disco/api/ratelimit.py
  6. 117
      disco/bot/bot.py
  7. 140
      disco/bot/command.py
  8. 76
      disco/bot/parser.py
  9. 187
      disco/bot/plugin.py
  10. 1
      disco/bot/providers/disk.py
  11. 25
      disco/bot/providers/redis.py
  12. 6
      disco/bot/providers/rocksdb.py
  13. 17
      disco/cli.py
  14. 45
      disco/client.py
  15. 52
      disco/gateway/client.py
  16. 4
      disco/gateway/encoding/base.py
  17. 2
      disco/gateway/encoding/json.py
  18. 424
      disco/gateway/events.py
  19. 91
      disco/gateway/ipc.py
  20. 4
      disco/gateway/packets.py
  21. 104
      disco/gateway/sharder.py
  22. 68
      disco/state.py
  23. 1
      disco/types/__init__.py
  24. 272
      disco/types/base.py
  25. 102
      disco/types/channel.py
  26. 129
      disco/types/guild.py
  27. 6
      disco/types/invite.py
  28. 200
      disco/types/message.py
  29. 12
      disco/types/permissions.py
  30. 40
      disco/types/user.py
  31. 2
      disco/types/voice.py
  32. 6
      disco/types/webhook.py
  33. 2
      disco/util/config.py
  34. 4
      disco/util/hashmap.py
  35. 3
      disco/util/limiter.py
  36. 35
      disco/util/logging.py
  37. 36
      disco/util/serializer.py
  38. 2
      disco/util/snowflake.py
  39. 2
      disco/util/token.py
  40. 7
      disco/voice/client.py
  41. 127
      disco/voice/opus.py
  42. 49
      disco/voice/player.py
  43. 9
      examples/music.py
  44. 2
      requirements.txt

3
README.md

@ -20,6 +20,7 @@ Disco was built to run both as a generic-use library, and a standalone bot toolk
|requests[security]|adds packages for a proper SSL implementation| |requests[security]|adds packages for a proper SSL implementation|
|ujson|faster json parser, improves performance| |ujson|faster json parser, improves performance|
|erlpack|ETF parser, only Python 2.x, run with the --encoder=etf flag| |erlpack|ETF parser, only Python 2.x, run with the --encoder=etf flag|
|gipc|Gevent IPC, required for autosharding|
## Examples ## Examples
@ -48,7 +49,7 @@ class SimplePlugin(Plugin):
Using the default bot configuration, we can now run this script like so: Using the default bot configuration, we can now run this script like so:
`python -m disco.cli --token="MY_DISCORD_TOKEN" --bot --plugin simpleplugin` `python -m disco.cli --token="MY_DISCORD_TOKEN" --run-bot --plugin simpleplugin`
And commands can be triggered by mentioning the bot (configued by the BotConfig.command\_require\_mention flag): And commands can be triggered by mentioning the bot (configued by the BotConfig.command\_require\_mention flag):

2
disco/__init__.py

@ -1 +1 @@
VERSION = '0.0.5' VERSION = '0.0.7'

151
disco/api/client.py

@ -1,11 +1,12 @@
import six import six
import json
from disco.api.http import Routes, HTTPClient from disco.api.http import Routes, HTTPClient
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
from disco.types.user import User from disco.types.user import User
from disco.types.message import Message from disco.types.message import Message
from disco.types.guild import Guild, GuildMember, Role from disco.types.guild import Guild, GuildMember, GuildBan, Role, GuildEmoji
from disco.types.channel import Channel from disco.types.channel import Channel
from disco.types.invite import Invite from disco.types.invite import Invite
from disco.types.webhook import Webhook from disco.types.webhook import Webhook
@ -23,18 +24,40 @@ def optional(**kwargs):
class APIClient(LoggingClass): class APIClient(LoggingClass):
""" """
An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits An abstraction over a :class:`disco.api.http.HTTPClient`, which composes
the models with the returned data. requests from provided data, and fits models with the returned data. The APIClient
is the only path to the API used within models/other interfaces, and it's
the recommended path for all third-party users/implementations.
Args
----
token : str
The Discord authentication token (without prefixes) to be used for all
HTTP requests.
client : Optional[:class:`disco.client.Client`]
The Disco client this APIClient is a member of. This is used when constructing
and fitting models from response data.
Attributes
----------
client : Optional[:class:`disco.client.Client`]
The Disco client this APIClient is a member of.
http : :class:`disco.http.HTTPClient`
The HTTPClient this APIClient uses for all requests.
""" """
def __init__(self, client): def __init__(self, token, client=None):
super(APIClient, self).__init__() super(APIClient, self).__init__()
self.client = client self.client = client
self.http = HTTPClient(self.client.config.token) self.http = HTTPClient(token)
def gateway(self, version, encoding): def gateway_get(self):
data = self.http(Routes.GATEWAY_GET).json() data = self.http(Routes.GATEWAY_GET).json()
return data['url'] + '?v={}&encoding={}'.format(version, encoding) return data
def gateway_bot_get(self):
data = self.http(Routes.GATEWAY_BOT_GET).json()
return data
def channels_get(self, channel): def channels_get(self, channel):
r = self.http(Routes.CHANNELS_GET, dict(channel=channel)) r = self.http(Routes.CHANNELS_GET, dict(channel=channel))
@ -48,6 +71,9 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel)) r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel))
return Channel.create(self.client, r.json()) return Channel.create(self.client, r.json())
def channels_typing(self, channel):
self.http(Routes.CHANNELS_TYPING, dict(channel=channel))
def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50): def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50):
r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional( r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional(
around=around, around=around,
@ -62,19 +88,36 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content, nonce=None, tts=False): def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None):
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={ payload = {
'content': content, 'content': content,
'nonce': nonce, 'nonce': nonce,
'tts': tts, 'tts': tts,
}) }
if embed:
payload['embed'] = embed.to_dict()
if attachment:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data={'payload_json': json.dumps(payload)}, files={
'file': (attachment[0], attachment[1])
})
else:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload)
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_modify(self, channel, message, content): def channels_messages_modify(self, channel, message, content, embed=None):
payload = {
'content': content,
}
if embed:
payload['embed'] = embed.to_dict()
r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, r = self.http(Routes.CHANNELS_MESSAGES_MODIFY,
dict(channel=channel, message=message), dict(channel=channel, message=message),
json={'content': content}) json=payload)
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_delete(self, channel, message): def channels_messages_delete(self, channel, message):
@ -83,6 +126,23 @@ class APIClient(LoggingClass):
def channels_messages_delete_bulk(self, channel, messages): def channels_messages_delete_bulk(self, channel, messages):
self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages}) self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages})
def channels_messages_reactions_get(self, channel, message, emoji):
r = self.http(Routes.CHANNELS_MESSAGES_REACTIONS_GET, dict(channel=channel, message=message, emoji=emoji))
return User.create_map(self.client, r.json())
def channels_messages_reactions_create(self, channel, message, emoji):
self.http(Routes.CHANNELS_MESSAGES_REACTIONS_CREATE, dict(channel=channel, message=message, emoji=emoji))
def channels_messages_reactions_delete(self, channel, message, emoji, user=None):
route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_ME
obj = dict(channel=channel, message=message, emoji=emoji)
if user:
route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_USER
obj['user'] = user
self.http(route, obj)
def channels_permissions_modify(self, channel, permission, allow, deny, typ): def channels_permissions_modify(self, channel, permission, allow, deny, typ):
self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={ self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={
'allow': allow, 'allow': allow,
@ -141,10 +201,28 @@ class APIClient(LoggingClass):
def guilds_channels_list(self, guild): def guilds_channels_list(self, guild):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_map(self.client, r.json(), guild_id=guild) return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_channels_create(self, guild, **kwargs): def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[]):
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) payload = {
'name': name,
'channel_type': channel_type,
'permission_overwrites': [i.to_dict() for i in permission_overwrites],
}
if channel_type == 'text':
pass
elif channel_type == 'voice':
if bitrate is not None:
payload['bitrate'] = bitrate
if user_limit is not None:
payload['user_limit'] = user_limit
else:
# TODO: better error here?
raise Exception('Invalid channel type: {}'.format(channel_type))
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=payload)
return Channel.create(self.client, r.json(), guild_id=guild) return Channel.create(self.client, r.json(), guild_id=guild)
def guilds_channels_modify(self, guild, channel, position): def guilds_channels_modify(self, guild, channel, position):
@ -155,21 +233,30 @@ class APIClient(LoggingClass):
def guilds_members_list(self, guild): def guilds_members_list(self, guild):
r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild))
return GuildMember.create_map(self.client, r.json(), guild_id=guild) return GuildMember.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_members_get(self, guild, member): def guilds_members_get(self, guild, member):
r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member)) r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member))
return GuildMember.create(self.client, r.json(), guild_id=guild) return GuildMember.create(self.client, r.json(), guild_id=guild)
def guilds_members_modify(self, guild, member, **kwargs): def guilds_members_modify(self, guild, member, **kwargs):
self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=optional(**kwargs))
def guilds_members_roles_add(self, guild, member, role):
self.http(Routes.GUILDS_MEMBERS_ROLES_ADD, dict(guild=guild, member=member, role=role))
def guilds_members_roles_remove(self, guild, member, role):
self.http(Routes.GUILDS_MEMBERS_ROLES_REMOVE, dict(guild=guild, member=member, role=role))
def guilds_members_me_nick(self, guild, nick):
self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick})
def guilds_members_kick(self, guild, member): def guilds_members_kick(self, guild, member):
self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member)) self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member))
def guilds_bans_list(self, guild): def guilds_bans_list(self, guild):
r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild))
return User.create_map(self.client, r.json()) return GuildBan.create_hash(self.client, 'user.id', r.json())
def guilds_bans_create(self, guild, user, delete_message_days): def guilds_bans_create(self, guild, user, delete_message_days):
self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={
@ -202,6 +289,28 @@ class APIClient(LoggingClass):
r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild))
return Webhook.create_map(self.client, r.json()) return Webhook.create_map(self.client, r.json())
def guilds_emojis_list(self, guild):
r = self.http(Routes.GUILDS_EMOJIS_LIST, dict(guild=guild))
return GuildEmoji.create_map(self.client, r.json())
def guilds_emojis_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs)
return GuildEmoji.create(self.client, r.json())
def guilds_emojis_modify(self, guild, emoji, **kwargs):
r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs)
return GuildEmoji.create(self.client, r.json())
def guilds_emojis_delete(self, guild, emoji):
self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji))
def users_me_get(self):
return User.create(self.client, self.http(Routes.USERS_ME_GET).json())
def users_me_patch(self, payload):
r = self.http(Routes.USERS_ME_PATCH, json=payload)
return User.create(self.client, r.json())
def invites_get(self, invite): def invites_get(self, invite):
r = self.http(Routes.INVITES_GET, dict(invite=invite)) r = self.http(Routes.INVITES_GET, dict(invite=invite))
return Invite.create(self.client, r.json()) return Invite.create(self.client, r.json())
@ -236,7 +345,7 @@ class APIClient(LoggingClass):
return Webhook.create(self.client, r.json()) return Webhook.create(self.client, r.json())
def webhooks_token_delete(self, webhook, token): def webhooks_token_delete(self, webhook, token):
self.http(Routes.WEBHOOKS_TOKEN_DLEETE, dict(webhook=webhook, token=token)) self.http(Routes.WEBHOOKS_TOKEN_DELETE, dict(webhook=webhook, token=token))
def webhooks_token_execute(self, webhook, token, data, wait=False): def webhooks_token_execute(self, webhook, token, data, wait=False):
obj = self.http( obj = self.http(

82
disco/api/http.py

@ -2,9 +2,12 @@ import requests
import random import random
import gevent import gevent
import six import six
import sys
from holster.enum import Enum from holster.enum import Enum
from disco import VERSION as disco_version
from requests import __version__ as requests_version
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
from disco.api.ratelimit import RateLimiter from disco.api.ratelimit import RateLimiter
@ -18,6 +21,12 @@ HTTPMethod = Enum(
) )
def to_bytes(obj):
if isinstance(obj, six.text_type):
return obj.encode('utf-8')
return obj
class Routes(object): class Routes(object):
""" """
Simple Python object-enum of all method/url route combinations available to Simple Python object-enum of all method/url route combinations available to
@ -25,18 +34,25 @@ class Routes(object):
""" """
# Gateway # Gateway
GATEWAY_GET = (HTTPMethod.GET, '/gateway') GATEWAY_GET = (HTTPMethod.GET, '/gateway')
GATEWAY_BOT_GET = (HTTPMethod.GET, '/gateway/bot')
# Channels # Channels
CHANNELS = '/channels/{channel}' CHANNELS = '/channels/{channel}'
CHANNELS_GET = (HTTPMethod.GET, CHANNELS) CHANNELS_GET = (HTTPMethod.GET, CHANNELS)
CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS) CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS)
CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS) CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS)
CHANNELS_TYPING = (HTTPMethod.POST, CHANNELS + '/typing')
CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages') CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages')
CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages') CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages')
CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete') CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete')
CHANNELS_MESSAGES_REACTIONS_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}/reactions/{emoji}')
CHANNELS_MESSAGES_REACTIONS_CREATE = (HTTPMethod.PUT, CHANNELS + '/messages/{message}/reactions/{emoji}/@me')
CHANNELS_MESSAGES_REACTIONS_DELETE_ME = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/@me')
CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE,
CHANNELS + '/messages/{message}/reactions/{emoji}/{user}')
CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}') CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}')
CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}') CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}')
CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites') CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites')
@ -58,6 +74,9 @@ class Routes(object):
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}')
GUILDS_MEMBERS_ROLES_ADD = (HTTPMethod.PUT, GUILDS + '/members/{member}/roles/{role}')
GUILDS_MEMBERS_ROLES_REMOVE = (HTTPMethod.DELETE, GUILDS + '/members/{member}/roles/{role}')
GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}')
GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}')
@ -79,6 +98,10 @@ class Routes(object):
GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed') GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed')
GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed') GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed')
GUILDS_WEBHOOKS_LIST = (HTTPMethod.GET, GUILDS + '/webhooks') GUILDS_WEBHOOKS_LIST = (HTTPMethod.GET, GUILDS + '/webhooks')
GUILDS_EMOJIS_LIST = (HTTPMethod.GET, GUILDS + '/emojis')
GUILDS_EMOJIS_CREATE = (HTTPMethod.POST, GUILDS + '/emojis')
GUILDS_EMOJIS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/emojis/{emoji}')
GUILDS_EMOJIS_DELETE = (HTTPMethod.DELETE, GUILDS + '/emojis/{emoji}')
# Users # Users
USERS = '/users' USERS = '/users'
@ -111,14 +134,39 @@ class APIException(Exception):
""" """
Exception thrown when an HTTP-client level error occurs. Usually this will Exception thrown when an HTTP-client level error occurs. Usually this will
be a non-success status-code, or a transient network issue. be a non-success status-code, or a transient network issue.
Attributes
----------
status_code : int
The status code returned by the API for the request that triggered this
error.
""" """
def __init__(self, msg, status_code=0, content=None): def __init__(self, response, retries=None):
self.status_code = status_code self.response = response
self.content = content self.retries = retries
self.msg = msg
self.code = 0
self.msg = 'Request Failed ({})'.format(response.status_code)
if self.status_code: if self.retries:
self.msg += ' code: {}'.format(status_code) self.msg += " after {} retries".format(self.retries)
# Try to decode JSON, and extract params
try:
data = self.response.json()
if 'code' in data:
self.code = data['code']
self.msg = data['message']
elif len(data) == 1:
key, value = list(data.items())[0]
self.msg = 'Request Failed: {}: {}'.format(key, ', '.join(value))
except ValueError:
pass
# DEPRECATED: left for backwards compat
self.status_code = response.status_code
self.content = response.content
super(APIException, self).__init__(self.msg) super(APIException, self).__init__(self.msg)
@ -134,9 +182,18 @@ class HTTPClient(LoggingClass):
def __init__(self, token): def __init__(self, token):
super(HTTPClient, self).__init__() super(HTTPClient, self).__init__()
py_version = '{}.{}.{}'.format(
sys.version_info.major,
sys.version_info.minor,
sys.version_info.micro)
self.limiter = RateLimiter() self.limiter = RateLimiter()
self.headers = { self.headers = {
'Authorization': 'Bot ' + token, 'Authorization': 'Bot ' + token,
'User-Agent': 'DiscordBot (https://github.com/b1naryth1ef/disco {}) Python/{} requests/{}'.format(
disco_version,
py_version,
requests_version),
} }
def __call__(self, route, args=None, **kwargs): def __call__(self, route, args=None, **kwargs):
@ -182,7 +239,8 @@ class HTTPClient(LoggingClass):
kwargs['headers'] = self.headers kwargs['headers'] = self.headers
# Build the bucket URL # Build the bucket URL
filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in six.iteritems(args)} args = {k: to_bytes(v) for k, v in six.iteritems(args)}
filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)}
bucket = (route[0].value, route[1].format(**filtered)) bucket = (route[0].value, route[1].format(**filtered))
# Possibly wait if we're rate limited # Possibly wait if we're rate limited
@ -190,6 +248,7 @@ class HTTPClient(LoggingClass):
# Make the actual request # Make the actual request
url = self.BASE_URL + route[1].format(**args) url = self.BASE_URL + route[1].format(**args)
self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params'))
r = requests.request(route[0].value, url, **kwargs) r = requests.request(route[0].value, url, **kwargs)
# Update rate limiter # Update rate limiter
@ -198,17 +257,18 @@ class HTTPClient(LoggingClass):
# If we got a success status code, just return the data # If we got a success status code, just return the data
if r.status_code < 400: if r.status_code < 400:
return r return r
elif r.status_code != 429 and 400 < r.status_code < 500: elif r.status_code != 429 and 400 <= r.status_code < 500:
raise APIException('Request failed', r.status_code, r.content) raise APIException(r)
else: else:
if r.status_code == 429: if r.status_code == 429:
self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync') self.log.warning(
'Request responded w/ 429, retrying (but this should not happen, check your clock sync')
# If we hit the max retries, throw an error # If we hit the max retries, throw an error
retry += 1 retry += 1
if retry > self.MAX_RETRIES: if retry > self.MAX_RETRIES:
self.log.error('Failing request, hit max retries') self.log.error('Failing request, hit max retries')
raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) raise APIException(r, retries=self.MAX_RETRIES)
backoff = self.random_backoff() backoff = self.random_backoff()
self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format( self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format(

21
disco/api/ratelimit.py

@ -1,8 +1,10 @@
import time import time
import gevent import gevent
from disco.util.logging import LoggingClass
class RouteState(object):
class RouteState(LoggingClass):
""" """
An object which stores ratelimit state for a given method/url route An object which stores ratelimit state for a given method/url route
combination (as specified in :class:`disco.api.http.Routes`). combination (as specified in :class:`disco.api.http.Routes`).
@ -36,10 +38,13 @@ class RouteState(object):
self.update(response) self.update(response)
def __repr__(self):
return '<RouteState {}>'.format(' '.join(self.route))
@property @property
def chilled(self): def chilled(self):
""" """
Whether this route is currently being cooldown (aka waiting until reset_time) Whether this route is currently being cooldown (aka waiting until reset_time).
""" """
return self.event is not None return self.event is not None
@ -69,7 +74,7 @@ class RouteState(object):
def wait(self, timeout=None): def wait(self, timeout=None):
""" """
Waits until this route is no longer under a cooldown Waits until this route is no longer under a cooldown.
Parameters Parameters
---------- ----------
@ -80,24 +85,26 @@ class RouteState(object):
Returns Returns
------- -------
bool bool
False if the timeout period expired before the cooldown was finished False if the timeout period expired before the cooldown was finished.
""" """
return self.event.wait(timeout) return self.event.wait(timeout)
def cooldown(self): def cooldown(self):
""" """
Waits for the current route to be cooled-down (aka waiting until reset time) Waits for the current route to be cooled-down (aka waiting until reset time).
""" """
if self.reset_time - time.time() < 0: if self.reset_time - time.time() < 0:
raise Exception('Cannot cooldown for negative time period; check clock sync') raise Exception('Cannot cooldown for negative time period; check clock sync')
self.event = gevent.event.Event() self.event = gevent.event.Event()
gevent.sleep((self.reset_time - time.time()) + .5) delay = (self.reset_time - time.time()) + .5
self.log.debug('Cooling down bucket %s for %s seconds', self, delay)
gevent.sleep(delay)
self.event.set() self.event.set()
self.event = None self.event = None
class RateLimiter(object): class RateLimiter(LoggingClass):
""" """
A in-memory store of ratelimit states for all routes we've ever called. A in-memory store of ratelimit states for all routes we've ever called.

117
disco/bot/bot.py

@ -12,6 +12,7 @@ from disco.bot.plugin import Plugin
from disco.bot.command import CommandEvent, CommandLevels from disco.bot.command import CommandEvent, CommandLevels
from disco.bot.storage import Storage from disco.bot.storage import Storage
from disco.util.config import Config from disco.util.config import Config
from disco.util.logging import LoggingClass
from disco.util.serializer import Serializer from disco.util.serializer import Serializer
@ -64,7 +65,7 @@ class BotConfig(Config):
The directory plugin configuration is located within. The directory plugin configuration is located within.
""" """
levels = {} levels = {}
plugins = [] plugin_config = {}
commands_enabled = True commands_enabled = True
commands_require_mention = True commands_require_mention = True
@ -88,7 +89,7 @@ class BotConfig(Config):
storage_config = {} storage_config = {}
class Bot(object): class Bot(LoggingClass):
""" """
Disco's implementation of a simple but extendable Discord bot. Bots consist Disco's implementation of a simple but extendable Discord bot. Bots consist
of a set of plugins, and a Disco client. of a set of plugins, and a Disco client.
@ -114,6 +115,9 @@ class Bot(object):
self.client = client self.client = client
self.config = config or BotConfig() self.config = config or BotConfig()
# Shard manager
self.shards = None
# The context carries information about events in a threadlocal storage # The context carries information about events in a threadlocal storage
self.ctx = ThreadLocal() self.ctx = ThreadLocal()
@ -122,6 +126,7 @@ class Bot(object):
if self.config.storage_enabled: if self.config.storage_enabled:
self.storage = Storage(self.ctx, self.config.from_prefix('storage')) self.storage = Storage(self.ctx, self.config.from_prefix('storage'))
# If the manhole is enabled, add this bot as a local
if self.client.config.manhole_enable: if self.client.config.manhole_enable:
self.client.manhole_locals['bot'] = self self.client.manhole_locals['bot'] = self
@ -135,6 +140,12 @@ class Bot(object):
if self.config.commands_allow_edit: if self.config.commands_allow_edit:
self.client.events.on('MessageUpdate', self.on_message_update) self.client.events.on('MessageUpdate', self.on_message_update)
# If we have a level getter and its a string, try to load it
if isinstance(self.config.commands_level_getter, six.string_types):
mod, func = self.config.commands_level_getter.rsplit('.', 1)
mod = importlib.import_module(mod)
self.config.commands_level_getter = getattr(mod, func)
# Stores the last message for every single channel # Stores the last message for every single channel
self.last_message_cache = {} self.last_message_cache = {}
@ -173,10 +184,10 @@ class Bot(object):
@property @property
def commands(self): def commands(self):
""" """
Generator of all commands this bots plugins have defined Generator of all commands this bots plugins have defined.
""" """
for plugin in six.itervalues(self.plugins): for plugin in six.itervalues(self.plugins):
for command in six.itervalues(plugin.commands): for command in plugin.commands:
yield command yield command
def recompute(self): def recompute(self):
@ -190,7 +201,7 @@ class Bot(object):
def compute_group_abbrev(self): def compute_group_abbrev(self):
""" """
Computes all possible abbreviations for a command grouping Computes all possible abbreviations for a command grouping.
""" """
self.group_abbrev = {} self.group_abbrev = {}
groups = set(command.group for command in self.commands if command.group) groups = set(command.group for command in self.commands if command.group)
@ -199,7 +210,7 @@ class Bot(object):
grp = group grp = group
while grp: while grp:
# If the group already exists, means someone else thought they # If the group already exists, means someone else thought they
# could use it so we need to # could use it so we need yank it from them (and not use it)
if grp in list(six.itervalues(self.group_abbrev)): if grp in list(six.itervalues(self.group_abbrev)):
self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp} self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp}
else: else:
@ -211,13 +222,14 @@ class Bot(object):
""" """
Computes a single regex which matches all possible command combinations. Computes a single regex which matches all possible command combinations.
""" """
re_str = '|'.join(command.regex for command in self.commands) commands = list(self.commands)
re_str = '|'.join(command.regex for command in commands)
if re_str: if re_str:
self.command_matches_re = re.compile(re_str) self.command_matches_re = re.compile(re_str, re.I)
else: else:
self.command_matches_re = None self.command_matches_re = None
def get_commands_for_message(self, msg): def get_commands_for_message(self, require_mention, mention_rules, prefix, msg):
""" """
Generator of all commands that a given message object triggers, based on Generator of all commands that a given message object triggers, based on
the bots plugins and configuration. the bots plugins and configuration.
@ -234,19 +246,19 @@ class Bot(object):
""" """
content = msg.content content = msg.content
if self.config.commands_require_mention: if require_mention:
mention_direct = msg.is_mentioned(self.client.state.me) mention_direct = msg.is_mentioned(self.client.state.me)
mention_everyone = msg.mention_everyone mention_everyone = msg.mention_everyone
mention_roles = [] mention_roles = []
if msg.guild: if msg.guild:
mention_roles = list(filter(lambda r: msg.is_mentioned(r), mention_roles = list(filter(lambda r: msg.is_mentioned(r),
msg.guild.get_member(self.client.state.me).roles)) msg.guild.get_member(self.client.state.me).roles))
if not any(( if not any((
self.config.commands_mention_rules['user'] and mention_direct, mention_rules.get('user', True) and mention_direct,
self.config.commands_mention_rules['everyone'] and mention_everyone, mention_rules.get('everyone', False) and mention_everyone,
self.config.commands_mention_rules['role'] and any(mention_roles), mention_rules.get('role', False) and any(mention_roles),
msg.channel.is_dm msg.channel.is_dm
)): )):
raise StopIteration raise StopIteration
@ -262,14 +274,14 @@ class Bot(object):
content = content.replace('@everyone', '', 1) content = content.replace('@everyone', '', 1)
else: else:
for role in mention_roles: for role in mention_roles:
content = content.replace(role.mention, '', 1) content = content.replace('<@{}>'.format(role), '', 1)
content = content.lstrip() content = content.lstrip()
if self.config.commands_prefix and not content.startswith(self.config.commands_prefix): if prefix and not content.startswith(prefix):
raise StopIteration raise StopIteration
else: else:
content = content[len(self.config.commands_prefix):] content = content[len(prefix):]
if not self.command_matches_re or not self.command_matches_re.match(content): if not self.command_matches_re or not self.command_matches_re.match(content):
raise StopIteration raise StopIteration
@ -283,7 +295,7 @@ class Bot(object):
level = CommandLevels.DEFAULT level = CommandLevels.DEFAULT
if callable(self.config.commands_level_getter): if callable(self.config.commands_level_getter):
level = self.config.commands_level_getter(actor) level = self.config.commands_level_getter(self, actor)
else: else:
if actor.id in self.config.levels: if actor.id in self.config.levels:
level = self.config.levels[actor.id] level = self.config.levels[actor.id]
@ -320,19 +332,24 @@ class Bot(object):
bool bool
whether any commands where successfully triggered by the message whether any commands where successfully triggered by the message
""" """
commands = list(self.get_commands_for_message(msg)) commands = list(self.get_commands_for_message(
self.config.commands_require_mention,
self.config.commands_mention_rules,
self.config.commands_prefix,
msg
))
if len(commands): if not len(commands):
result = False return False
for command, match in commands:
if not self.check_command_permissions(command, msg):
continue
if command.plugin.execute(CommandEvent(command, msg, match)): result = False
result = True for command, match in commands:
return result if not self.check_command_permissions(command, msg):
continue
return False if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
def on_message_create(self, event): def on_message_create(self, event):
if event.message.author.id == self.client.state.me.id: if event.message.author.id == self.client.state.me.id:
@ -356,7 +373,7 @@ class Bot(object):
self.last_message_cache[msg.channel_id] = (msg, triggered) self.last_message_cache[msg.channel_id] = (msg, triggered)
def add_plugin(self, cls, config=None): def add_plugin(self, cls, config=None, ctx=None):
""" """
Adds and loads a plugin, based on its class. Adds and loads a plugin, based on its class.
@ -366,8 +383,12 @@ class Bot(object):
Plugin class to initialize and load. Plugin class to initialize and load.
config : Optional config : Optional
The configuration to load the plugin with. The configuration to load the plugin with.
ctx : Optional[dict]
Context (previous state) to pass the plugin. Usually used along w/
unload.
""" """
if cls.__name__ in self.plugins: if cls.__name__ in self.plugins:
self.log.warning('Attempted to add already added plugin %s', cls.__name__)
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))
if not config: if not config:
@ -376,9 +397,10 @@ class Bot(object):
else: else:
config = self.load_plugin_config(cls) config = self.load_plugin_config(cls)
self.plugins[cls.__name__] = cls(self, config) self.ctx['plugin'] = self.plugins[cls.__name__] = cls(self, config)
self.plugins[cls.__name__].load() self.plugins[cls.__name__].load(ctx or {})
self.recompute() self.recompute()
self.ctx.drop()
def rmv_plugin(self, cls): def rmv_plugin(self, cls):
""" """
@ -392,9 +414,11 @@ class Bot(object):
if cls.__name__ not in self.plugins: if cls.__name__ not in self.plugins:
raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__)) raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__].unload() ctx = {}
self.plugins[cls.__name__].unload(ctx)
del self.plugins[cls.__name__] del self.plugins[cls.__name__]
self.recompute() self.recompute()
return ctx
def reload_plugin(self, cls): def reload_plugin(self, cls):
""" """
@ -402,13 +426,13 @@ class Bot(object):
""" """
config = self.plugins[cls.__name__].config config = self.plugins[cls.__name__].config
self.rmv_plugin(cls) ctx = self.rmv_plugin(cls)
module = reload_module(inspect.getmodule(cls)) module = reload_module(inspect.getmodule(cls))
self.add_plugin(getattr(module, cls.__name__), config) self.add_plugin(getattr(module, cls.__name__), config, ctx)
def run_forever(self): def run_forever(self):
""" """
Runs this bots core loop forever Runs this bots core loop forever.
""" """
self.client.run_forever() self.client.run_forever()
@ -416,12 +440,14 @@ class Bot(object):
""" """
Adds and loads a plugin, based on its module path. Adds and loads a plugin, based on its module path.
""" """
self.log.info('Adding plugin module at path "%s"', path)
mod = importlib.import_module(path) mod = importlib.import_module(path)
loaded = False loaded = False
for entry in map(lambda i: getattr(mod, i), dir(mod)): for entry in map(lambda i: getattr(mod, i), dir(mod)):
if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin: if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin:
if getattr(entry, '_shallow', False) and Plugin in entry.__bases__:
continue
loaded = True loaded = True
self.add_plugin(entry, config) self.add_plugin(entry, config)
@ -430,23 +456,24 @@ class Bot(object):
def load_plugin_config(self, cls): def load_plugin_config(self, cls):
name = cls.__name__.lower() name = cls.__name__.lower()
if name.startswith('plugin'): if name.endswith('plugin'):
name = name[6:] name = name[:-6]
path = os.path.join( path = os.path.join(
self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format
if not os.path.exists(path): data = {}
if hasattr(cls, 'config_cls'): if name in self.config.plugin_config:
return cls.config_cls() data = self.config.plugin_config[name]
return
with open(path, 'r') as f: if os.path.exists(path):
data = Serializer.loads(self.config.plugin_config_format, f.read()) with open(path, 'r') as f:
data.update(Serializer.loads(self.config.plugin_config_format, f.read()))
if hasattr(cls, 'config_cls'): if hasattr(cls, 'config_cls'):
inst = cls.config_cls() inst = cls.config_cls()
inst.update(data) if data:
inst.update(data)
return inst return inst
return data return data

140
disco/bot/command.py

@ -5,9 +5,11 @@ from holster.enum import Enum
from disco.bot.parser import ArgumentSet, ArgumentError from disco.bot.parser import ArgumentSet, ArgumentError
from disco.util.functional import cached_property from disco.util.functional import cached_property
REGEX_FMT = '({})' ARGS_REGEX = '(?: ((?:\n|.)*)$|$)'
ARGS_REGEX = '( (.*)$|$)'
MENTION_RE = re.compile('<@!?([0-9]+)>') USER_MENTION_RE = re.compile('<@!?([0-9]+)>')
ROLE_MENTION_RE = re.compile('<@&([0-9]+)>')
CHANNEL_MENTION_RE = re.compile('<#([0-9]+)>')
CommandLevels = Enum( CommandLevels = Enum(
DEFAULT=0, DEFAULT=0,
@ -42,34 +44,52 @@ class CommandEvent(object):
self.command = command self.command = command
self.msg = msg self.msg = msg
self.match = match self.match = match
self.name = self.match.group(1) self.name = self.match.group(0)
self.args = [i for i in self.match.group(2).strip().split(' ') if i] self.args = []
if self.match.group(1):
self.args = [i for i in self.match.group(1).strip().split(' ') if i]
@property
def codeblock(self):
if '`' not in self.msg.content:
return ' '.join(self.args)
_, src = self.msg.content.split('`', 1)
src = '`' + src
if src.startswith('```') and src.endswith('```'):
src = src[3:-3]
elif src.startswith('`') and src.endswith('`'):
src = src[1:-1]
return src
@cached_property @cached_property
def member(self): def member(self):
""" """
Guild member (if relevant) for the user that created the message Guild member (if relevant) for the user that created the message.
""" """
return self.guild.get_member(self.author) return self.guild.get_member(self.author)
@property @property
def channel(self): def channel(self):
""" """
Channel the message was created in Channel the message was created in.
""" """
return self.msg.channel return self.msg.channel
@property @property
def guild(self): def guild(self):
""" """
Guild (if relevant) the message was created in Guild (if relevant) the message was created in.
""" """
return self.msg.guild return self.msg.guild
@property @property
def author(self): def author(self):
""" """
Author of the message Author of the message.
""" """
return self.msg.author return self.msg.author
@ -107,61 +127,106 @@ class Command(object):
self.plugin = plugin self.plugin = plugin
self.func = func self.func = func
self.triggers = [trigger] self.triggers = [trigger]
self.dispatch_func = None
self.raw_args = None
self.args = None
self.level = None
self.group = None
self.is_regex = None
self.oob = False
self.context = {}
self.metadata = {}
self.update(*args, **kwargs) self.update(*args, **kwargs)
def update(self, args=None, level=None, aliases=None, group=None, is_regex=None): def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def get_docstring(self):
return (self.func.__doc__ or '').format(**self.context)
def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None, **kwargs):
self.triggers += aliases or [] self.triggers += aliases or []
def resolve_role(ctx, id): def resolve_role(ctx, rid):
return ctx.msg.guild.roles.get(id) return ctx.msg.guild.roles.get(rid)
def resolve_user(ctx, uid):
if isinstance(uid, int):
if uid in ctx.msg.mentions:
return ctx.msg.mentions.get(uid)
else:
return ctx.msg.client.state.users.get(uid)
else:
return ctx.msg.client.state.users.select_one(username=uid[0], discriminator=uid[1])
def resolve_channel(ctx, cid):
if isinstance(cid, (int, long)):
return ctx.msg.guild.channels.get(cid)
else:
return ctx.msg.guild.channels.select_one(name=cid)
def resolve_user(ctx, id): def resolve_guild(ctx, gid):
return ctx.msg.mentions.get(id) return ctx.msg.client.state.guilds.get(gid)
self.raw_args = args
self.args = ArgumentSet.from_string(args or '', { self.args = ArgumentSet.from_string(args or '', {
'mention': self.mention_type([resolve_role, resolve_user]), 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True),
'user': self.mention_type([resolve_user], force=True), 'role': self.mention_type([resolve_role], ROLE_MENTION_RE),
'role': self.mention_type([resolve_role], force=True), 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE, allow_plain=True),
'guild': self.mention_type([resolve_guild]),
}) })
self.level = level self.level = level
self.group = group self.group = group
self.is_regex = is_regex self.is_regex = is_regex
self.oob = oob
self.context = context or {}
self.metadata = kwargs
@staticmethod @staticmethod
def mention_type(getters, force=False): def mention_type(getters, reg=None, user=False, allow_plain=False):
def _f(ctx, i): def _f(ctx, raw):
res = MENTION_RE.match(i) if raw.isdigit():
if not res: resolved = int(raw)
raise TypeError('Invalid mention: {}'.format(i)) elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit():
username, discrim = raw.split('#')
id = int(res.group(1)) resolved = (username, int(discrim))
elif reg:
res = reg.match(raw)
if res:
resolved = int(res.group(1))
else:
if allow_plain:
resolved = raw
else:
raise TypeError('Invalid mention: {}'.format(raw))
else:
raise TypeError('Invalid mention: {}'.format(raw))
for getter in getters: for getter in getters:
obj = getter(ctx, id) obj = getter(ctx, resolved)
if obj: if obj:
return obj return obj
if force: raise TypeError('Cannot resolve mention: {}'.format(raw))
raise TypeError('Cannot resolve mention: {}'.format(id))
return id
return _f return _f
@cached_property @cached_property
def compiled_regex(self): def compiled_regex(self):
""" """
A compiled version of this command's regex A compiled version of this command's regex.
""" """
return re.compile(self.regex) return re.compile(self.regex, re.I)
@property @property
def regex(self): def regex(self):
""" """
The regex string that defines/triggers this command The regex string that defines/triggers this command.
""" """
if self.is_regex: if self.is_regex:
return REGEX_FMT.format('|'.join(self.triggers)) return '|'.join(self.triggers)
else: else:
group = '' group = ''
if self.group: if self.group:
@ -169,7 +234,7 @@ class Command(object):
group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group)) group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group))
else: else:
group = self.group + ' ' group = self.group + ' '
return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) return '^{}(?:{})'.format(group, '|'.join(self.triggers)) + ARGS_REGEX
def execute(self, event): def execute(self, event):
""" """
@ -189,8 +254,11 @@ class Command(object):
)) ))
try: try:
args = self.args.parse(event.args, ctx=event) parsed_args = self.args.parse(event.args, ctx=event)
except ArgumentError as e: except ArgumentError as e:
raise CommandError(e.message) raise CommandError(e.message)
return self.func(event, *args) kwargs = {}
kwargs.update(self.context)
kwargs.update(parsed_args)
return self.plugin.dispatch('command', self, event, **kwargs)

76
disco/bot/parser.py

@ -2,9 +2,10 @@ import re
import six import six
import copy import copy
# Regex which splits out argument parts # Regex which splits out argument parts
PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])') PARTS_RE = re.compile('(\<|\[|\{)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\]|\})')
BOOL_OPTS = {'yes': True, 'no': False, 'true': True, 'False': False, '1': True, '0': False}
# Mapping of types # Mapping of types
TYPE_MAP = { TYPE_MAP = {
@ -14,6 +15,20 @@ TYPE_MAP = {
'snowflake': lambda ctx, data: int(data), 'snowflake': lambda ctx, data: int(data),
} }
try:
import dateparser
TYPE_MAP['duration'] = lambda ctx, data: dateparser.parse(data, settings={'TIMEZONE': 'UTC'})
except ImportError:
pass
def to_bool(ctx, data):
if data in BOOL_OPTS:
return BOOL_OPTS[data]
raise TypeError
TYPE_MAP['bool'] = to_bool
class ArgumentError(Exception): class ArgumentError(Exception):
""" """
@ -41,19 +56,20 @@ class Argument(object):
self.name = None self.name = None
self.count = 1 self.count = 1
self.required = False self.required = False
self.flag = False
self.types = None self.types = None
self.parse(raw) self.parse(raw)
@property @property
def true_count(self): def true_count(self):
""" """
The true number of raw arguments this argument takes The true number of raw arguments this argument takes.
""" """
return self.count or 1 return self.count or 1
def parse(self, raw): def parse(self, raw):
""" """
Attempts to parse arguments from their raw form Attempts to parse arguments from their raw form.
""" """
prefix, part = raw prefix, part = raw
@ -62,23 +78,27 @@ class Argument(object):
else: else:
self.required = False self.required = False
if part.endswith('...'): # Whether this is a flag
part = part[:-3] self.flag = (prefix == '{')
self.count = 0
elif ' ' in part: if not self.flag:
split = part.split(' ', 1) if part.endswith('...'):
part, self.count = split[0], int(split[1]) part = part[:-3]
self.count = 0
elif ' ' in part:
split = part.split(' ', 1)
part, self.count = split[0], int(split[1])
if ':' in part: if ':' in part:
part, typeinfo = part.split(':') part, typeinfo = part.split(':')
self.types = typeinfo.split('|') self.types = typeinfo.split('|')
self.name = part.strip() self.name = part.strip()
class ArgumentSet(object): class ArgumentSet(object):
""" """
A set of :class:`Argument` instances which forms a larger argument specification A set of :class:`Argument` instances which forms a larger argument specification.
Attributes Attributes
---------- ----------
@ -95,7 +115,7 @@ class ArgumentSet(object):
@classmethod @classmethod
def from_string(cls, line, custom_types=None): def from_string(cls, line, custom_types=None):
""" """
Creates a new :class:`ArgumentSet` from a given argument string specification Creates a new :class:`ArgumentSet` from a given argument string specification.
""" """
args = cls(custom_types=custom_types) args = cls(custom_types=custom_types)
@ -131,7 +151,7 @@ class ArgumentSet(object):
def append(self, arg): def append(self, arg):
""" """
Add a new :class:`Argument` to this argument specification/set Add a new :class:`Argument` to this argument specification/set.
""" """
if self.args and not self.args[-1].required and arg.required: if self.args and not self.args[-1].required and arg.required:
raise Exception('Required argument cannot come after an optional argument') raise Exception('Required argument cannot come after an optional argument')
@ -145,9 +165,23 @@ class ArgumentSet(object):
""" """
Parse a string of raw arguments into this argument specification. Parse a string of raw arguments into this argument specification.
""" """
parsed = [] parsed = {}
flags = {i.name: i for i in self.args if i.flag}
if flags:
new_rawargs = []
for offset, raw in enumerate(rawargs):
if raw.startswith('-'):
raw = raw.lstrip('-')
if raw in flags:
parsed[raw] = True
continue
new_rawargs.append(raw)
rawargs = new_rawargs
for index, arg in enumerate(self.args): for index, arg in enumerate((arg for arg in self.args if not arg.flag)):
if not arg.required and index + arg.true_count > len(rawargs): if not arg.required and index + arg.true_count > len(rawargs):
continue continue
@ -171,20 +205,20 @@ class ArgumentSet(object):
if (not arg.types or arg.types == ['str']) and isinstance(raw, list): if (not arg.types or arg.types == ['str']) and isinstance(raw, list):
raw = ' '.join(raw) raw = ' '.join(raw)
parsed.append(raw) parsed[arg.name] = raw
return parsed return parsed
@property @property
def length(self): def length(self):
""" """
The number of arguments in this set/specification The number of arguments in this set/specification.
""" """
return len(self.args) return len(self.args)
@property @property
def required_length(self): def required_length(self):
""" """
The number of required arguments to compile this set/specificaiton The number of required arguments to compile this set/specificaiton.
""" """
return sum([i.true_count for i in self.args if i.required]) return sum([i.true_count for i in self.args if i.required])

187
disco/bot/plugin.py

@ -1,9 +1,11 @@
import six import six
import types
import gevent import gevent
import inspect import inspect
import weakref import weakref
import functools import functools
from gevent.event import AsyncResult
from holster.emitter import Priority from holster.emitter import Priority
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
@ -18,8 +20,8 @@ class PluginDeco(object):
Prio = Priority Prio = Priority
# TODO: dont smash class methods # TODO: dont smash class methods
@staticmethod @classmethod
def add_meta_deco(meta): def add_meta_deco(cls, meta):
def deco(f): def deco(f):
if not hasattr(f, 'meta'): if not hasattr(f, 'meta'):
f.meta = [] f.meta = []
@ -40,33 +42,33 @@ class PluginDeco(object):
return deco return deco
@classmethod @classmethod
def listen(cls, event_name, priority=None): def listen(cls, *args, **kwargs):
""" """
Binds the function to listen for a given event name Binds the function to listen for a given event name.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'listener', 'type': 'listener',
'what': 'event', 'what': 'event',
'desc': event_name, 'args': args,
'priority': priority 'kwargs': kwargs,
}) })
@classmethod @classmethod
def listen_packet(cls, op, priority=None): def listen_packet(cls, *args, **kwargs):
""" """
Binds the function to listen for a given gateway op code Binds the function to listen for a given gateway op code.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'listener', 'type': 'listener',
'what': 'packet', 'what': 'packet',
'desc': op, 'args': args,
'priority': priority, 'kwargs': kwargs,
}) })
@classmethod @classmethod
def command(cls, *args, **kwargs): def command(cls, *args, **kwargs):
""" """
Creates a new command attached to the function Creates a new command attached to the function.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'command', 'type': 'command',
@ -77,7 +79,7 @@ class PluginDeco(object):
@classmethod @classmethod
def pre_command(cls): def pre_command(cls):
""" """
Runs a function before a command is triggered Runs a function before a command is triggered.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'pre_command', 'type': 'pre_command',
@ -86,7 +88,7 @@ class PluginDeco(object):
@classmethod @classmethod
def post_command(cls): def post_command(cls):
""" """
Runs a function after a command is triggered Runs a function after a command is triggered.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'post_command', 'type': 'post_command',
@ -95,7 +97,7 @@ class PluginDeco(object):
@classmethod @classmethod
def pre_listener(cls): def pre_listener(cls):
""" """
Runs a function before a listener is triggered Runs a function before a listener is triggered.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'pre_listener', 'type': 'pre_listener',
@ -104,7 +106,7 @@ class PluginDeco(object):
@classmethod @classmethod
def post_listener(cls): def post_listener(cls):
""" """
Runs a function after a listener is triggered Runs a function after a listener is triggered.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'post_listener', 'type': 'post_listener',
@ -113,7 +115,7 @@ class PluginDeco(object):
@classmethod @classmethod
def schedule(cls, *args, **kwargs): def schedule(cls, *args, **kwargs):
""" """
Runs a function repeatedly, waiting for a specified interval Runs a function repeatedly, waiting for a specified interval.
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'schedule', 'type': 'schedule',
@ -153,46 +155,101 @@ class Plugin(LoggingClass, PluginDeco):
self.storage = bot.storage self.storage = bot.storage
self.config = config self.config = config
# General declartions
self.listeners = []
self.commands = []
self.schedules = {}
self.greenlets = weakref.WeakSet()
self._pre = {}
self._post = {}
# This is an array of all meta functions we sniff at init
self.meta_funcs = []
for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'):
self.meta_funcs.append(member)
# Unsmash local functions
if hasattr(Plugin, name):
method = types.MethodType(getattr(Plugin, name), self, self.__class__)
setattr(self, name, method)
self.bind_all()
@property @property
def name(self): def name(self):
return self.__class__.__name__ return self.__class__.__name__
def bind_all(self): def bind_all(self):
self.listeners = [] self.listeners = []
self.commands = {} self.commands = []
self.schedules = {} self.schedules = {}
self.greenlets = weakref.WeakSet() self.greenlets = weakref.WeakSet()
self._pre = {'command': [], 'listener': []} self._pre = {'command': [], 'listener': []}
self._post = {'command': [], 'listener': []} self._post = {'command': [], 'listener': []}
# TODO: when handling events/commands we need to track the greenlet in for member in self.meta_funcs:
# the greenlets set so we can termiante long running commands/listeners for meta in member.meta:
# on reload. self.bind_meta(member, meta)
def bind_meta(self, member, meta):
if meta['type'] == 'listener':
self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs'])
elif meta['type'] == 'command':
# meta['kwargs']['update'] = True
self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'] == 'schedule':
self.register_schedule(member, *meta['args'], **meta['kwargs'])
elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'):
when, typ = meta['type'].split('_', 1)
self.register_trigger(typ, when, member)
def handle_exception(self, greenlet, event):
pass
def wait_for_event(self, event_name, **kwargs):
result = AsyncResult()
listener = None
def _event_callback(event):
for k, v in kwargs.items():
if getattr(event, k) != v:
break
else:
listener.remove()
return result.set(event)
for name, member in inspect.getmembers(self, predicate=inspect.ismethod): listener = self.bot.client.events.on(event_name, _event_callback)
if hasattr(member, 'meta'):
for meta in member.meta: return result
if meta['type'] == 'listener':
self.register_listener(member, meta['what'], meta['desc'], meta['priority']) def spawn_wrap(self, spawner, method, *args, **kwargs):
elif meta['type'] == 'command': def wrapped(*args, **kwargs):
meta['kwargs']['update'] = True self.ctx['plugin'] = self
self.register_command(member, *meta['args'], **meta['kwargs']) try:
elif meta['type'] == 'schedule': res = method(*args, **kwargs)
self.register_schedule(member, *meta['args'], **meta['kwargs']) return res
elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): finally:
when, typ = meta['type'].split('_', 1) self.ctx.drop()
self.register_trigger(typ, when, member)
obj = spawner(wrapped, *args, **kwargs)
def spawn(self, method, *args, **kwargs):
obj = gevent.spawn(method, *args, **kwargs)
self.greenlets.add(obj) self.greenlets.add(obj)
return obj return obj
def spawn(self, *args, **kwargs):
return self.spawn_wrap(gevent.spawn, *args, **kwargs)
def spawn_later(self, delay, *args, **kwargs):
return self.spawn_wrap(functools.partial(gevent.spawn_later, delay), *args, **kwargs)
def execute(self, event): def execute(self, event):
""" """
Executes a CommandEvent this plugin owns Executes a CommandEvent this plugin owns.
""" """
if not event.command.oob:
self.greenlets.add(gevent.getcurrent())
try: try:
return event.command.execute(event) return event.command.execute(event)
except CommandError as e: except CommandError as e:
@ -203,11 +260,18 @@ class Plugin(LoggingClass, PluginDeco):
def register_trigger(self, typ, when, func): def register_trigger(self, typ, when, func):
""" """
Registers a trigger Registers a trigger.
""" """
getattr(self, '_' + when)[typ].append(func) getattr(self, '_' + when)[typ].append(func)
def _dispatch(self, typ, func, event, *args, **kwargs): def dispatch(self, typ, func, event, *args, **kwargs):
# Link the greenlet with our exception handler
gevent.getcurrent().link_exception(lambda g: self.handle_exception(g, event))
# TODO: this is ugly
if typ != 'command':
self.greenlets.add(gevent.getcurrent())
self.ctx['plugin'] = self self.ctx['plugin'] = self
if hasattr(event, 'guild'): if hasattr(event, 'guild'):
@ -218,7 +282,7 @@ class Plugin(LoggingClass, PluginDeco):
self.ctx['user'] = event.author self.ctx['user'] = event.author
for pre in self._pre[typ]: for pre in self._pre[typ]:
event = pre(event, args, kwargs) event = pre(func, event, args, kwargs)
if event is None: if event is None:
return False return False
@ -226,13 +290,13 @@ class Plugin(LoggingClass, PluginDeco):
result = func(event, *args, **kwargs) result = func(event, *args, **kwargs)
for post in self._post[typ]: for post in self._post[typ]:
post(event, args, kwargs, result) post(func, event, args, kwargs, result)
return True return True
def register_listener(self, func, what, desc, priority): def register_listener(self, func, what, *args, **kwargs):
""" """
Registers a listener Registers a listener.
Parameters Parameters
---------- ----------
@ -242,17 +306,13 @@ class Plugin(LoggingClass, PluginDeco):
The function to be registered. The function to be registered.
desc desc
The descriptor of the event/packet. The descriptor of the event/packet.
priority : Priority
The priority of this listener.
""" """
func = functools.partial(self._dispatch, 'listener', func) args = list(args) + [functools.partial(self.dispatch, 'listener', func)]
priority = priority or Priority.NONE
if what == 'event': if what == 'event':
li = self.bot.client.events.on(desc, func, priority=priority) li = self.bot.client.events.on(*args, **kwargs)
elif what == 'packet': elif what == 'packet':
li = self.bot.client.packets.on(desc, func, priority=priority) li = self.bot.client.packets.on(*args, **kwargs)
else: else:
raise Exception('Invalid listener what: {}'.format(what)) raise Exception('Invalid listener what: {}'.format(what))
@ -260,7 +320,7 @@ class Plugin(LoggingClass, PluginDeco):
def register_command(self, func, *args, **kwargs): def register_command(self, func, *args, **kwargs):
""" """
Registers a command Registers a command.
Parameters Parameters
---------- ----------
@ -272,11 +332,7 @@ class Plugin(LoggingClass, PluginDeco):
Keyword arguments to pass onto the :class:`disco.bot.command.Command` Keyword arguments to pass onto the :class:`disco.bot.command.Command`
object. object.
""" """
if kwargs.pop('update', False) and func.__name__ in self.commands: self.commands.append(Command(self, func, *args, **kwargs))
self.commands[func.__name__].update(*args, **kwargs)
else:
wrapped = functools.partial(self._dispatch, 'command', func)
self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs)
def register_schedule(self, func, interval, repeat=True, init=True): def register_schedule(self, func, interval, repeat=True, init=True):
""" """
@ -289,8 +345,13 @@ class Plugin(LoggingClass, PluginDeco):
The function to be registered. The function to be registered.
interval : int interval : int
Interval (in seconds) to repeat the function on. Interval (in seconds) to repeat the function on.
repeat : bool
Whether this schedule is repeating (or one time).
init : bool
Whether to run this schedule once immediatly, or wait for the first
scheduled iteration.
""" """
def repeat(): def repeat_func():
if init: if init:
func() func()
@ -300,17 +361,17 @@ class Plugin(LoggingClass, PluginDeco):
if not repeat: if not repeat:
break break
self.schedules[func.__name__] = self.spawn(repeat) self.schedules[func.__name__] = self.spawn(repeat_func)
def load(self): def load(self, ctx):
""" """
Called when the plugin is loaded Called when the plugin is loaded.
""" """
self.bind_all() pass
def unload(self): def unload(self, ctx):
""" """
Called when the plugin is unloaded Called when the plugin is unloaded.
""" """
for greenlet in self.greenlets: for greenlet in self.greenlets:
greenlet.kill() greenlet.kill()

1
disco/bot/providers/disk.py

@ -13,6 +13,7 @@ class DiskProvider(BaseProvider):
self.fsync = config.get('fsync', False) self.fsync = config.get('fsync', False)
self.fsync_changes = config.get('fsync_changes', 1) self.fsync_changes = config.get('fsync_changes', 1)
self.autosave_task = None
self.change_count = 0 self.change_count = 0
def autosave_loop(self, interval): def autosave_loop(self, interval):

25
disco/bot/providers/redis.py

@ -10,32 +10,39 @@ from .base import BaseProvider, SEP_SENTINEL
class RedisProvider(BaseProvider): class RedisProvider(BaseProvider):
def __init__(self, config): def __init__(self, config):
self.config = config super(RedisProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.conn = None
def load(self): def load(self):
self.redis = redis.Redis( self.conn = redis.Redis(
host=self.config.get('host', 'localhost'), host=self.config.get('host', 'localhost'),
port=self.config.get('port', 6379), port=self.config.get('port', 6379),
db=self.config.get('db', 0)) db=self.config.get('db', 0))
def exists(self, key): def exists(self, key):
return self.db.exists(key) return self.conn.exists(key)
def keys(self, other): def keys(self, other):
count = other.count(SEP_SENTINEL) + 1 count = other.count(SEP_SENTINEL) + 1
for key in self.db.scan_iter(u'{}*'.format(other)): for key in self.conn.scan_iter(u'{}*'.format(other)):
key = key.decode('utf-8')
if key.count(SEP_SENTINEL) == count: if key.count(SEP_SENTINEL) == count:
yield key yield key
def get_many(self, keys): def get_many(self, keys):
for key, value in izip(keys, self.db.mget(keys)): keys = list(keys)
if not len(keys):
raise StopIteration
for key, value in izip(keys, self.conn.mget(keys)):
yield (key, Serializer.loads(self.format, value)) yield (key, Serializer.loads(self.format, value))
def get(self, key): def get(self, key):
return Serializer.loads(self.format, self.db.get(key)) return Serializer.loads(self.format, self.conn.get(key))
def set(self, key, value): def set(self, key, value):
self.db.set(key, Serializer.dumps(self.format, value)) self.conn.set(key, Serializer.dumps(self.format, value))
def delete(self, key, value): def delete(self, key):
self.db.delete(key) self.conn.delete(key)

6
disco/bot/providers/rocksdb.py

@ -12,11 +12,13 @@ from .base import BaseProvider, SEP_SENTINEL
class RocksDBProvider(BaseProvider): class RocksDBProvider(BaseProvider):
def __init__(self, config): def __init__(self, config):
self.config = config super(RocksDBProvider, self).__init__(config)
self.format = config.get('format', 'pickle') self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage.db') self.path = config.get('path', 'storage.db')
self.db = None
def k(self, k): @staticmethod
def k(k):
return bytes(k) if six.PY3 else str(k.encode('utf-8')) return bytes(k) if six.PY3 else str(k.encode('utf-8'))
def load(self): def load(self):

17
disco/cli.py

@ -18,14 +18,13 @@ parser.add_argument('--config', help='Configuration file', default='config.yaml'
parser.add_argument('--token', help='Bot Authentication Token', default=None) parser.add_argument('--token', help='Bot Authentication Token', default=None)
parser.add_argument('--shard-count', help='Total number of shards', default=None) parser.add_argument('--shard-count', help='Total number of shards', default=None)
parser.add_argument('--shard-id', help='Current shard number/id', default=None) parser.add_argument('--shard-id', help='Current shard number/id', default=None)
parser.add_argument('--shard-auto', help='Automatically run all shards', action='store_true', default=False)
parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None) parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None)
parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None) parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None)
parser.add_argument('--encoder', help='encoder for gateway data', default=None) parser.add_argument('--encoder', help='encoder for gateway data', default=None)
parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False) parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False)
parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[]) parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[])
logging.basicConfig(level=logging.INFO)
def disco_main(run=False): def disco_main(run=False):
""" """
@ -42,6 +41,7 @@ def disco_main(run=False):
from disco.client import Client, ClientConfig from disco.client import Client, ClientConfig
from disco.bot import Bot, BotConfig from disco.bot import Bot, BotConfig
from disco.util.token import is_valid_token from disco.util.token import is_valid_token
from disco.util.logging import setup_logging
if os.path.exists(args.config): if os.path.exists(args.config):
config = ClientConfig.from_file(args.config) config = ClientConfig.from_file(args.config)
@ -56,12 +56,23 @@ def disco_main(run=False):
print('Invalid token passed') print('Invalid token passed')
return return
if args.shard_auto:
from disco.gateway.sharder import AutoSharder
AutoSharder(config).run()
return
# TODO: make configurable
setup_logging(level=logging.INFO)
client = Client(config) client = Client(config)
bot = None bot = None
if args.run_bot or hasattr(config, 'bot'): if args.run_bot or hasattr(config, 'bot'):
bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig() bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig()
bot_config.plugins += args.plugin if not hasattr(bot_config, 'plugins'):
bot_config.plugins = args.plugin
else:
bot_config.plugins += args.plugin
bot = Bot(client, bot_config) bot = Bot(client, bot_config)
if run: if run:

45
disco/client.py

@ -1,3 +1,4 @@
import time
import gevent import gevent
from holster.emitter import Emitter from holster.emitter import Emitter
@ -5,24 +6,28 @@ from holster.emitter import Emitter
from disco.state import State, StateConfig from disco.state import State, StateConfig
from disco.api.client import APIClient from disco.api.client import APIClient
from disco.gateway.client import GatewayClient from disco.gateway.client import GatewayClient
from disco.gateway.packets import OPCode
from disco.types.user import Status, Game
from disco.util.config import Config from disco.util.config import Config
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
from disco.util.backdoor import DiscoBackdoorServer from disco.util.backdoor import DiscoBackdoorServer
class ClientConfig(LoggingClass, Config): class ClientConfig(Config):
""" """
Configuration for the :class:`Client`. Configuration for the :class:`Client`.
Attributes Attributes
---------- ----------
token : str token : str
Discord authentication token, ca be validated using the Discord authentication token, can be validated using the
:func:`disco.util.token.is_valid_token` function. :func:`disco.util.token.is_valid_token` function.
shard_id : int shard_id : int
The shard ID for the current client instance. The shard ID for the current client instance.
shard_count : int shard_count : int
The total count of shards running. The total count of shards running.
max_reconnects : int
The maximum number of connection retries to make before giving up (0 = never give up).
manhole_enable : bool manhole_enable : bool
Whether to enable the manhole (e.g. console backdoor server) utility. Whether to enable the manhole (e.g. console backdoor server) utility.
manhole_bind : tuple(str, int) manhole_bind : tuple(str, int)
@ -36,14 +41,15 @@ class ClientConfig(LoggingClass, Config):
token = "" token = ""
shard_id = 0 shard_id = 0
shard_count = 1 shard_count = 1
max_reconnects = 5
manhole_enable = True manhole_enable = False
manhole_bind = ('127.0.0.1', 8484) manhole_bind = ('127.0.0.1', 8484)
encoder = 'json' encoder = 'json'
class Client(object): class Client(LoggingClass):
""" """
Class representing the base entry point that should be used in almost all Class representing the base entry point that should be used in almost all
implementation cases. This class wraps the functionality of both the REST API implementation cases. This class wraps the functionality of both the REST API
@ -82,8 +88,8 @@ class Client(object):
self.events = Emitter(gevent.spawn) self.events = Emitter(gevent.spawn)
self.packets = Emitter(gevent.spawn) self.packets = Emitter(gevent.spawn)
self.api = APIClient(self) self.api = APIClient(self.config.token, self)
self.gw = GatewayClient(self, self.config.encoder) self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder)
self.state = State(self, StateConfig(self.config.get('state', {}))) self.state = State(self, StateConfig(self.config.get('state', {})))
if self.config.manhole_enable: if self.config.manhole_enable:
@ -95,18 +101,37 @@ class Client(object):
} }
self.manhole = DiscoBackdoorServer(self.config.manhole_bind, self.manhole = DiscoBackdoorServer(self.config.manhole_bind,
banner='Disco Manhole', banner='Disco Manhole',
localf=lambda: self.manhole_locals) localf=lambda: self.manhole_locals)
self.manhole.start() self.manhole.start()
def update_presence(self, game=None, status=None, afk=False, since=0.0):
if game and not isinstance(game, Game):
raise TypeError('Game must be a Game model')
if status is Status.IDLE and not since:
since = int(time.time() * 1000)
payload = {
'afk': afk,
'since': since,
'status': status.value.lower(),
'game': None,
}
if game:
payload['game'] = game.to_dict()
self.gw.send(OPCode.STATUS_UPDATE, payload)
def run(self): def run(self):
""" """
Run the client (e.g. the :class:`GatewayClient`) in a new greenlet Run the client (e.g. the :class:`GatewayClient`) in a new greenlet.
""" """
return gevent.spawn(self.gw.run) return gevent.spawn(self.gw.run)
def run_forever(self): def run_forever(self):
""" """
Run the client (e.g. the :class:`GatewayClient`) in the current greenlet Run the client (e.g. the :class:`GatewayClient`) in the current greenlet.
""" """
return self.gw.run() return self.gw.run()

52
disco/gateway/client.py

@ -15,16 +15,21 @@ TEN_MEGABYTES = 10490000
class GatewayClient(LoggingClass): class GatewayClient(LoggingClass):
GATEWAY_VERSION = 6 GATEWAY_VERSION = 6
MAX_RECONNECTS = 5
def __init__(self, client, encoder='json'): def __init__(self, client, max_reconnects=5, encoder='json', ipc=None):
super(GatewayClient, self).__init__() super(GatewayClient, self).__init__()
self.client = client self.client = client
self.max_reconnects = max_reconnects
self.encoder = ENCODERS[encoder] self.encoder = ENCODERS[encoder]
self.events = client.events self.events = client.events
self.packets = client.packets self.packets = client.packets
# IPC for shards
if ipc:
self.shards = ipc.get_shards()
self.ipc = ipc
# Its actually 60, 120 but lets give ourselves a buffer # Its actually 60, 120 but lets give ourselves a buffer
self.limiter = SimpleLimiter(60, 130) self.limiter = SimpleLimiter(60, 130)
@ -37,6 +42,7 @@ class GatewayClient(LoggingClass):
# Bind to ready payload # Bind to ready payload
self.events.on('Ready', self.on_ready) self.events.on('Ready', self.on_ready)
self.events.on('Resumed', self.on_resumed)
# Websocket connection # Websocket connection
self.ws = None self.ws = None
@ -76,15 +82,15 @@ class GatewayClient(LoggingClass):
self.log.debug('Dispatching %s', obj.__class__.__name__) self.log.debug('Dispatching %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj) self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet): def handle_heartbeat(self, _):
self._send(OPCode.HEARTBEAT, self.seq) self._send(OPCode.HEARTBEAT, self.seq)
def handle_reconnect(self, packet): def handle_reconnect(self, _):
self.log.warning('Received RECONNECT request, forcing a fresh reconnect') self.log.warning('Received RECONNECT request, forcing a fresh reconnect')
self.session_id = None self.session_id = None
self.ws.close() self.ws.close()
def handle_invalid_session(self, packet): def handle_invalid_session(self, _):
self.log.warning('Recieved INVALID_SESSION, forcing a fresh reconnect') self.log.warning('Recieved INVALID_SESSION, forcing a fresh reconnect')
self.session_id = None self.session_id = None
self.ws.close() self.ws.close()
@ -98,14 +104,21 @@ class GatewayClient(LoggingClass):
self.session_id = ready.session_id self.session_id = ready.session_id
self.reconnects = 0 self.reconnects = 0
def connect_and_run(self): def on_resumed(self, _):
if not self._cached_gateway_url: self.log.info('Recieved RESUMED')
self._cached_gateway_url = self.client.api.gateway( self.reconnects = 0
version=self.GATEWAY_VERSION,
encoding=self.encoder.TYPE) def connect_and_run(self, gateway_url=None):
if not gateway_url:
if not self._cached_gateway_url:
self._cached_gateway_url = self.client.api.gateway_get()['url']
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) gateway_url = self._cached_gateway_url
self.ws = Websocket(self._cached_gateway_url)
gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE)
self.log.info('Opening websocket connection to URL `%s`', gateway_url)
self.ws = Websocket(gateway_url)
self.ws.emitter.on('on_open', self.on_open) self.ws.emitter.on('on_open', self.on_open)
self.ws.emitter.on('on_error', self.on_error) self.ws.emitter.on('on_error', self.on_error)
self.ws.emitter.on('on_close', self.on_close) self.ws.emitter.on('on_close', self.on_close)
@ -153,8 +166,8 @@ class GatewayClient(LoggingClass):
'compress': True, 'compress': True,
'large_threshold': 250, 'large_threshold': 250,
'shard': [ 'shard': [
self.client.config.shard_id, int(self.client.config.shard_id),
self.client.config.shard_count, int(self.client.config.shard_count),
], ],
'properties': { 'properties': {
'$os': 'linux', '$os': 'linux',
@ -165,15 +178,22 @@ class GatewayClient(LoggingClass):
}) })
def on_close(self, code, reason): def on_close(self, code, reason):
# Kill heartbeater, a reconnect/resume will trigger a HELLO which will
# respawn it
if self._heartbeat_task:
self._heartbeat_task.kill()
# If we're quitting, just break out of here
if self.shutting_down: if self.shutting_down:
self.log.info('WS Closed: shutting down') self.log.info('WS Closed: shutting down')
return return
# Track reconnect attempts
self.reconnects += 1 self.reconnects += 1
self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects)
if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS: if self.max_reconnects and self.reconnects > self.max_reconnects:
raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.max_reconnects))
# Don't resume for these error codes # Don't resume for these error codes
if code and 4000 <= code <= 4010: if code and 4000 <= code <= 4010:

4
disco/gateway/encoding/base.py

@ -1,7 +1,9 @@
from websocket import ABNF from websocket import ABNF
from holster.interface import Interface
class BaseEncoder(object):
class BaseEncoder(Interface):
TYPE = None TYPE = None
OPCODE = ABNF.OPCODE_TEXT OPCODE = ABNF.OPCODE_TEXT

2
disco/gateway/encoding/json.py

@ -1,7 +1,5 @@
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
import six
try: try:
import ujson as json import ujson as json
except ImportError: except ImportError:

424
disco/gateway/events.py

@ -4,20 +4,20 @@ import inflection
import six import six
from disco.types.user import User, Presence from disco.types.user import User, Presence
from disco.types.channel import Channel from disco.types.channel import Channel, PermissionOverwrite
from disco.types.message import Message from disco.types.message import Message, MessageReactionEmoji
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.guild import Guild, GuildMember, Role from disco.types.guild import Guild, GuildMember, Role, GuildEmoji
from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, datetime
# Mapping of discords event name to our event classes # Mapping of discords event name to our event classes
EVENTS_MAP = {} EVENTS_MAP = {}
class GatewayEventMeta(ModelMeta): class GatewayEventMeta(ModelMeta):
def __new__(cls, name, parents, dct): def __new__(mcs, name, parents, dct):
obj = super(GatewayEventMeta, cls).__new__(cls, name, parents, dct) obj = super(GatewayEventMeta, mcs).__new__(mcs, name, parents, dct)
if name != 'GatewayEvent': if name != 'GatewayEvent':
EVENTS_MAP[inflection.underscore(name).upper()] = obj EVENTS_MAP[inflection.underscore(name).upper()] = obj
@ -64,22 +64,21 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)):
return cls(obj, client) return cls(obj, client)
def __getattr__(self, name): def __getattr__(self, name):
if hasattr(self, '_wraps_model'): if hasattr(self, '_proxy'):
modname, _ = self._wraps_model return getattr(getattr(self, self._proxy), name)
if hasattr(self, modname) and hasattr(getattr(self, modname), name): return object.__getattribute__(self, name)
return getattr(getattr(self, modname), name)
raise AttributeError(name)
def debug(func=None): def debug(func=None, match=None):
def deco(cls): def deco(cls):
old_init = cls.__init__ old_init = cls.__init__
def new_init(self, obj, *args, **kwargs): def new_init(self, obj, *args, **kwargs):
if func: if not match or match(obj):
print(func(obj)) if func:
else: print(func(obj))
print(obj) else:
print(obj)
old_init(self, obj, *args, **kwargs) old_init(self, obj, *args, **kwargs)
@ -93,8 +92,16 @@ def wraps_model(model, alias=None):
def deco(cls): def deco(cls):
cls._fields[alias] = Field(model) cls._fields[alias] = Field(model)
cls._fields[alias].set_name(alias) cls._fields[alias].name = alias
cls._wraps_model = (alias, model) cls._wraps_model = (alias, model)
cls._proxy = alias
return cls
return deco
def proxy(field):
def deco(cls):
cls._proxy = field
return cls return cls
return deco return deco
@ -103,49 +110,102 @@ class Ready(GatewayEvent):
""" """
Sent after the initial gateway handshake is complete. Contains data required Sent after the initial gateway handshake is complete. Contains data required
for bootstrapping the client's states. for bootstrapping the client's states.
Attributes
-----
version : int
The gateway version.
session_id : str
The session ID.
user : :class:`disco.types.user.User`
The user object for the authed account.
guilds : list[:class:`disco.types.guild.Guild`
All guilds this account is a member of. These are shallow guild objects.
private_channels list[:class:`disco.types.channel.Channel`]
All private channels (DMs) open for this account.
""" """
version = Field(int, alias='v') version = Field(int, alias='v')
session_id = Field(str) session_id = Field(str)
user = Field(User) user = Field(User)
guilds = Field(listof(Guild)) guilds = ListField(Guild)
private_channels = Field(listof(Channel)) private_channels = ListField(Channel)
trace = ListField(str, alias='_trace')
class Resumed(GatewayEvent): class Resumed(GatewayEvent):
""" """
Sent after a resume completes. Sent after a resume completes.
""" """
pass trace = ListField(str, alias='_trace')
@wraps_model(Guild) @wraps_model(Guild)
class GuildCreate(GatewayEvent): class GuildCreate(GatewayEvent):
""" """
Sent when a guild is created, or becomes available. Sent when a guild is joined, or becomes available.
Attributes
-----
guild : :class:`disco.types.guild.Guild`
The guild being created (e.g. joined)
unavailable : bool
If false, this guild is coming online from a previously unavailable state,
and if None, this is a normal guild join event.
""" """
unavailable = Field(bool) unavailable = Field(bool)
@property
def created(self):
"""
Shortcut property which is true when we actually joined the guild.
"""
return self.unavailable is None
@wraps_model(Guild) @wraps_model(Guild)
class GuildUpdate(GatewayEvent): class GuildUpdate(GatewayEvent):
""" """
Sent when a guild is updated. Sent when a guild is updated.
Attributes
-----
guild : :class:`disco.types.guild.Guild`
The updated guild object.
""" """
pass
class GuildDelete(GatewayEvent): class GuildDelete(GatewayEvent):
""" """
Sent when a guild is deleted, or becomes unavailable. Sent when a guild is deleted, left, or becomes unavailable.
Attributes
-----
id : snowflake
The ID of the guild being deleted.
unavailable : bool
If true, this guild is becoming unavailable, if None this is a normal
guild leave event.
""" """
id = Field(snowflake) id = Field(snowflake)
unavailable = Field(bool) unavailable = Field(bool)
@property
def deleted(self):
"""
Shortcut property which is true when we actually have left the guild.
"""
return self.unavailable is None
@wraps_model(Channel) @wraps_model(Channel)
class ChannelCreate(GatewayEvent): class ChannelCreate(GatewayEvent):
""" """
Sent when a channel is created. Sent when a channel is created.
Attributes
-----
channel : :class:`disco.types.channel.Channel`
The channel which was created.
""" """
@ -153,115 +213,236 @@ class ChannelCreate(GatewayEvent):
class ChannelUpdate(ChannelCreate): class ChannelUpdate(ChannelCreate):
""" """
Sent when a channel is updated. Sent when a channel is updated.
Attributes
-----
channel : :class:`disco.types.channel.Channel`
The channel which was updated.
""" """
pass overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites')
@wraps_model(Channel) @wraps_model(Channel)
class ChannelDelete(ChannelCreate): class ChannelDelete(ChannelCreate):
""" """
Sent when a channel is deleted. Sent when a channel is deleted.
Attributes
-----
channel : :class:`disco.types.channel.Channel`
The channel being deleted.
""" """
pass
class ChannelPinsUpdate(GatewayEvent): class ChannelPinsUpdate(GatewayEvent):
""" """
Sent when a channel's pins are updated. Sent when a channel's pins are updated.
Attributes
-----
channel_id : snowflake
ID of the channel where pins where updated.
last_pin_timestap : datetime
The time the last message was pinned.
""" """
channel_id = Field(snowflake) channel_id = Field(snowflake)
last_pin_timestamp = Field(lazy_datetime) last_pin_timestamp = Field(datetime)
@wraps_model(User) @proxy(User)
class GuildBanAdd(GatewayEvent): class GuildBanAdd(GatewayEvent):
""" """
Sent when a user is banned from a guild. Sent when a user is banned from a guild.
Attributes
-----
guild_id : snowflake
The ID of the guild the user is being banned from.
user : :class:`disco.types.user.User`
The user being banned from the guild.
""" """
pass guild_id = Field(snowflake)
user = Field(User)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(User)
@proxy(User)
class GuildBanRemove(GuildBanAdd): class GuildBanRemove(GuildBanAdd):
""" """
Sent when a user is unbanned from a guild. Sent when a user is unbanned from a guild.
Attributes
-----
guild_id : snowflake
The ID of the guild the user is being unbanned from.
user : :class:`disco.types.user.User`
The user being unbanned from the guild.
""" """
pass
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class GuildEmojisUpdate(GatewayEvent): class GuildEmojisUpdate(GatewayEvent):
""" """
Sent when a guild's emojis are updated. Sent when a guild's emojis are updated.
Attributes
-----
guild_id : snowflake
The ID of the guild the emojis are being updated in.
emojis : list[:class:`disco.types.guild.Emoji`]
The new set of emojis for the guild
""" """
pass guild_id = Field(snowflake)
emojis = ListField(GuildEmoji)
class GuildIntegrationsUpdate(GatewayEvent): class GuildIntegrationsUpdate(GatewayEvent):
""" """
Sent when a guild's integrations are updated. Sent when a guild's integrations are updated.
Attributes
-----
guild_id : snowflake
The ID of the guild integrations where updated in.
""" """
pass guild_id = Field(snowflake)
class GuildMembersChunk(GatewayEvent): class GuildMembersChunk(GatewayEvent):
""" """
Sent in response to a member's chunk request. Sent in response to a member's chunk request.
Attributes
-----
guild_id : snowflake
The ID of the guild this member chunk is for.
members : list[:class:`disco.types.guild.GuildMember`]
The chunk of members.
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
members = Field(listof(GuildMember)) members = ListField(GuildMember)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(GuildMember, alias='member') @wraps_model(GuildMember, alias='member')
class GuildMemberAdd(GatewayEvent): class GuildMemberAdd(GatewayEvent):
""" """
Sent when a user joins a guild. Sent when a user joins a guild.
Attributes
-----
member : :class:`disco.types.guild.GuildMember`
The member that has joined the guild.
""" """
pass
@proxy('user')
class GuildMemberRemove(GatewayEvent): class GuildMemberRemove(GatewayEvent):
""" """
Sent when a user leaves a guild (via leaving, kicking, or banning). Sent when a user leaves a guild (via leaving, kicking, or banning).
Attributes
-----
guild_id : snowflake
The ID of the guild the member left from.
user : :class:`disco.types.user.User`
The user who was removed from the guild.
""" """
guild_id = Field(snowflake)
user = Field(User) user = Field(User)
guild_id = Field(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(GuildMember, alias='member') @wraps_model(GuildMember, alias='member')
class GuildMemberUpdate(GatewayEvent): class GuildMemberUpdate(GatewayEvent):
""" """
Sent when a guilds member is updated. Sent when a guilds member is updated.
Attributes
-----
member : :class:`disco.types.guild.GuildMember`
The member being updated
""" """
pass
@proxy('role')
class GuildRoleCreate(GatewayEvent): class GuildRoleCreate(GatewayEvent):
""" """
Sent when a role is created. Sent when a role is created.
Attributes
-----
guild_id : snowflake
The ID of the guild where the role was created.
role : :class:`disco.types.guild.Role`
The role that was created.
""" """
guild_id = Field(snowflake)
role = Field(Role) role = Field(Role)
guild_id = Field(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@proxy('role')
class GuildRoleUpdate(GuildRoleCreate): class GuildRoleUpdate(GuildRoleCreate):
""" """
Sent when a role is updated. Sent when a role is updated.
Attributes
-----
guild_id : snowflake
The ID of the guild where the role was created.
role : :class:`disco.types.guild.Role`
The role that was created.
""" """
pass
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class GuildRoleDelete(GatewayEvent): class GuildRoleDelete(GatewayEvent):
""" """
Sent when a role is deleted. Sent when a role is deleted.
Attributes
-----
guild_id : snowflake
The ID of the guild where the role is being deleted.
role_id : snowflake
The id of the role being deleted.
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
role_id = Field(snowflake) role_id = Field(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(Message) @wraps_model(Message)
class MessageCreate(GatewayEvent): class MessageCreate(GatewayEvent):
""" """
Sent when a message is created. Sent when a message is created.
Attributes
-----
message : :class:`disco.types.message.Message`
The message being created.
""" """
@ -269,55 +450,124 @@ class MessageCreate(GatewayEvent):
class MessageUpdate(MessageCreate): class MessageUpdate(MessageCreate):
""" """
Sent when a message is updated/edited. Sent when a message is updated/edited.
Attributes
-----
message : :class:`disco.types.message.Message`
The message being updated.
""" """
pass
class MessageDelete(GatewayEvent): class MessageDelete(GatewayEvent):
""" """
Sent when a message is deleted. Sent when a message is deleted.
Attributes
-----
id : snowflake
The ID of message being deleted.
channel_id : snowflake
The ID of the channel the message was deleted in.
""" """
id = Field(snowflake) id = Field(snowflake)
channel_id = Field(snowflake) channel_id = Field(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageDeleteBulk(GatewayEvent): class MessageDeleteBulk(GatewayEvent):
""" """
Sent when multiple messages are deleted from a channel. Sent when multiple messages are deleted from a channel.
Attributes
-----
channel_id : snowflake
The channel the messages are being deleted in.
ids : list[snowflake]
List of messages being deleted in the channel.
""" """
channel_id = Field(snowflake) channel_id = Field(snowflake)
ids = Field(listof(snowflake)) ids = ListField(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
@wraps_model(Presence) @wraps_model(Presence)
class PresenceUpdate(GatewayEvent): class PresenceUpdate(GatewayEvent):
""" """
Sent when a user's presence is updated. Sent when a user's presence is updated.
Attributes
-----
presence : :class:`disco.types.user.Presence`
The updated presence object.
guild_id : snowflake
The guild this presence update is for.
roles : list[snowflake]
List of roles the user from the presence is part of.
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
roles = Field(listof(snowflake)) roles = ListField(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class TypingStart(GatewayEvent): class TypingStart(GatewayEvent):
""" """
Sent when a user begins typing in a channel. Sent when a user begins typing in a channel.
Attributes
-----
channel_id : snowflake
The ID of the channel where the user is typing.
user_id : snowflake
The ID of the user who is typing.
timestamp : datetime
When the user started typing.
""" """
channel_id = Field(snowflake) channel_id = Field(snowflake)
user_id = Field(snowflake) user_id = Field(snowflake)
timestamp = Field(snowflake) timestamp = Field(datetime)
@wraps_model(VoiceState, alias='state') @wraps_model(VoiceState, alias='state')
class VoiceStateUpdate(GatewayEvent): class VoiceStateUpdate(GatewayEvent):
""" """
Sent when a users voice state changes. Sent when a users voice state changes.
Attributes
-----
state : :class:`disco.models.voice.VoiceState`
The voice state which was updated.
""" """
pass
class VoiceServerUpdate(GatewayEvent): class VoiceServerUpdate(GatewayEvent):
""" """
Sent when a voice server is updated. Sent when a voice server is updated.
Attributes
-----
token : str
The token for the voice server.
endpoint : str
The endpoint for the voice server.
guild_id : snowflake
The guild ID this voice server update is for.
""" """
token = Field(str) token = Field(str)
endpoint = Field(str) endpoint = Field(str)
@ -327,6 +577,94 @@ class VoiceServerUpdate(GatewayEvent):
class WebhooksUpdate(GatewayEvent): class WebhooksUpdate(GatewayEvent):
""" """
Sent when a channels webhooks are updated. Sent when a channels webhooks are updated.
Attributes
-----
channel_id : snowflake
The channel ID this webhooks update is for.
guild_id : snowflake
The guild ID this webhooks update is for.
""" """
channel_id = Field(snowflake) channel_id = Field(snowflake)
guild_id = Field(snowflake) guild_id = Field(snowflake)
class MessageReactionAdd(GatewayEvent):
"""
Sent when a reaction is added to a message.
Attributes
----------
channel_id : snowflake
The channel ID the message is in.
messsage_id : snowflake
The ID of the message for which the reaction was added too.
user_id : snowflake
The ID of the user who added the reaction.
emoji : :class:`disco.types.message.MessageReactionEmoji`
The emoji which was added.
"""
channel_id = Field(snowflake)
message_id = Field(snowflake)
user_id = Field(snowflake)
emoji = Field(MessageReactionEmoji)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageReactionRemove(GatewayEvent):
"""
Sent when a reaction is removed from a message.
Attributes
----------
channel_id : snowflake
The channel ID the message is in.
messsage_id : snowflake
The ID of the message for which the reaction was removed from.
user_id : snowflake
The ID of the user who originally added the reaction.
emoji : :class:`disco.types.message.MessageReactionEmoji`
The emoji which was removed.
"""
channel_id = Field(snowflake)
message_id = Field(snowflake)
user_id = Field(snowflake)
emoji = Field(MessageReactionEmoji)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageReactionRemoveAll(GatewayEvent):
"""
Sent when all reactions are removed from a message.
Attributes
----------
channel_id : snowflake
The channel ID the message is in.
message_id : snowflake
The ID of the message for which the reactions where removed from.
"""
channel_id = Field(snowflake)
message_id = Field(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild

91
disco/gateway/ipc.py

@ -0,0 +1,91 @@
import random
import gevent
import string
import weakref
from holster.enum import Enum
from disco.util.logging import LoggingClass
from disco.util.serializer import dump_function, load_function
def get_random_str(size):
return ''.join([random.choice(string.printable) for _ in range(size)])
IPCMessageType = Enum(
'CALL_FUNC',
'GET_ATTR',
'EXECUTE',
'RESPONSE',
)
class GIPCProxy(LoggingClass):
def __init__(self, obj, pipe):
super(GIPCProxy, self).__init__()
self.obj = obj
self.pipe = pipe
self.results = weakref.WeakValueDictionary()
gevent.spawn(self.read_loop)
def resolve(self, parts):
base = self.obj
for part in parts:
base = getattr(base, part)
return base
def send(self, typ, data):
self.pipe.put((typ.value, data))
def handle(self, mtype, data):
if mtype == IPCMessageType.CALL_FUNC:
nonce, func, args, kwargs = data
res = self.resolve(func)(*args, **kwargs)
self.send(IPCMessageType.RESPONSE, (nonce, res))
elif mtype == IPCMessageType.GET_ATTR:
nonce, path = data
self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path)))
elif mtype == IPCMessageType.EXECUTE:
nonce, raw = data
func = load_function(raw)
try:
result = func(self.obj)
except Exception:
self.log.exception('Failed to EXECUTE: ')
result = None
self.send(IPCMessageType.RESPONSE, (nonce, result))
elif mtype == IPCMessageType.RESPONSE:
nonce, res = data
if nonce in self.results:
self.results[nonce].set(res)
def read_loop(self):
while True:
mtype, data = self.pipe.get()
try:
self.handle(mtype, data)
except:
self.log.exception('Error in GIPCProxy:')
def execute(self, func):
nonce = get_random_str(32)
raw = dump_function(func)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw)))
return result
def get(self, path):
nonce = get_random_str(32)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.GET_ATTR.value, (nonce, path)))
return result
def call(self, path, *args, **kwargs):
nonce = get_random_str(32)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.CALL_FUNC.value, (nonce, path, args, kwargs)))
return result

4
disco/gateway/packets.py

@ -1,7 +1,7 @@
from holster.enum import Enum from holster.enum import Enum
SEND = object() SEND = 1
RECV = object() RECV = 2
OPCode = Enum( OPCode = Enum(
DISPATCH=0, DISPATCH=0,

104
disco/gateway/sharder.py

@ -0,0 +1,104 @@
from __future__ import absolute_import
import gipc
import gevent
import pickle
import logging
import marshal
from six.moves import range
from disco.client import Client
from disco.bot import Bot, BotConfig
from disco.api.client import APIClient
from disco.gateway.ipc import GIPCProxy
from disco.util.logging import setup_logging
from disco.util.snowflake import calculate_shard
from disco.util.serializer import dump_function, load_function
def run_shard(config, shard_id, pipe):
setup_logging(
level=logging.INFO,
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(shard_id)
)
config.shard_id = shard_id
client = Client(config)
bot = Bot(client, BotConfig(config.bot))
bot.sharder = GIPCProxy(bot, pipe)
bot.shards = ShardHelper(config.shard_count, bot)
bot.run_forever()
class ShardHelper(object):
def __init__(self, count, bot):
self.count = count
self.bot = bot
def keys(self):
for sid in range(self.count):
yield sid
def on(self, id, func):
if id == self.bot.client.config.shard_id:
result = gevent.event.AsyncResult()
result.set(func(self.bot))
return result
return self.bot.sharder.call(('run_on', ), id, dump_function(func))
def all(self, func, timeout=None):
pool = gevent.pool.Pool(self.count)
return dict(zip(range(self.count), pool.imap(lambda i: self.on(i, func).wait(timeout=timeout), range(self.count))))
def for_id(self, sid, func):
shard = calculate_shard(self.count, sid)
return self.on(shard, func)
class AutoSharder(object):
def __init__(self, config):
self.config = config
self.client = APIClient(config.token)
self.shards = {}
self.config.shard_count = self.client.gateway_bot_get()['shards']
def run_on(self, sid, raw):
func = load_function(raw)
return self.shards[sid].execute(func).wait(timeout=15)
def run(self):
for shard in range(self.config.shard_count):
if self.config.manhole_enable and shard != 0:
self.config.manhole_enable = False
self.start_shard(shard)
gevent.sleep(6)
logging.basicConfig(
level=logging.INFO,
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id)
)
@staticmethod
def dumps(data):
if isinstance(data, (basestring, int, long, bool, list, set, dict)):
return '\x01' + marshal.dumps(data)
elif isinstance(data, object) and data.__class__.__name__ == 'code':
return '\x01' + marshal.dumps(data)
else:
return '\x02' + pickle.dumps(data)
@staticmethod
def loads(data):
enc_type = data[0]
if enc_type == '\x01':
return marshal.loads(data[1:])
elif enc_type == '\x02':
return pickle.loads(data[1:])
def start_shard(self, sid):
cpipe, ppipe = gipc.pipe(duplex=True, encoder=self.dumps, decoder=self.loads)
gipc.start_process(run_shard, (self.config, sid, cpipe))
self.shards[sid] = GIPCProxy(self, ppipe)

68
disco/state.py

@ -1,10 +1,11 @@
import six import six
import weakref
import inflection import inflection
from collections import deque, namedtuple from collections import deque, namedtuple
from weakref import WeakValueDictionary
from gevent.event import Event from gevent.event import Event
from disco.types.base import UNSET
from disco.util.config import Config from disco.util.config import Config
from disco.util.hashmap import HashMap, DefaultHashMap from disco.util.hashmap import HashMap, DefaultHashMap
@ -88,7 +89,7 @@ class State(object):
EVENTS = [ EVENTS = [
'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove', 'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove',
'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete', 'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete',
'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate', 'GuildEmojisUpdate', 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate',
'PresenceUpdate' 'PresenceUpdate'
] ]
@ -102,9 +103,9 @@ class State(object):
self.me = None self.me = None
self.dms = HashMap() self.dms = HashMap()
self.guilds = HashMap() self.guilds = HashMap()
self.channels = HashMap(WeakValueDictionary()) self.channels = HashMap(weakref.WeakValueDictionary())
self.users = HashMap(WeakValueDictionary()) self.users = HashMap(weakref.WeakValueDictionary())
self.voice_states = HashMap(WeakValueDictionary()) self.voice_states = HashMap(weakref.WeakValueDictionary())
# If message tracking is enabled, listen to those events # If message tracking is enabled, listen to those events
if self.config.track_messages: if self.config.track_messages:
@ -117,7 +118,7 @@ class State(object):
def unbind(self): def unbind(self):
""" """
Unbinds all bound event listeners for this state object Unbinds all bound event listeners for this state object.
""" """
map(lambda k: k.unbind(), self.listeners) map(lambda k: k.unbind(), self.listeners)
self.listeners = [] self.listeners = []
@ -185,11 +186,19 @@ class State(object):
for member in six.itervalues(event.guild.members): for member in six.itervalues(event.guild.members):
self.users[member.user.id] = member.user self.users[member.user.id] = member.user
for voice_state in six.itervalues(event.guild.voice_states):
self.voice_states[voice_state.session_id] = voice_state
if self.config.sync_guild_members: if self.config.sync_guild_members:
event.guild.sync() event.guild.sync()
def on_guild_update(self, event): def on_guild_update(self, event):
self.guilds[event.guild.id].update(event.guild) self.guilds[event.guild.id].update(event.guild, ignored=[
'channels',
'members',
'voice_states',
'presences'
])
def on_guild_delete(self, event): def on_guild_delete(self, event):
if event.id in self.guilds: if event.id in self.guilds:
@ -208,6 +217,10 @@ class State(object):
if event.channel.id in self.channels: if event.channel.id in self.channels:
self.channels[event.channel.id].update(event.channel) self.channels[event.channel.id].update(event.channel)
if event.overwrites is not UNSET:
self.channels[event.channel.id].overwrites = event.overwrites
self.channels[event.channel.id].after_load()
def on_channel_delete(self, event): def on_channel_delete(self, event):
if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels: if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels:
del event.channel.guild.channels[event.channel.id] del event.channel.guild.channels[event.channel.id]
@ -215,18 +228,22 @@ class State(object):
del self.dms[event.channel.id] del self.dms[event.channel.id]
def on_voice_state_update(self, event): def on_voice_state_update(self, event):
# Happy path: we have the voice state and want to update/delete it # Existing connection, we are either moving channels or disconnecting
guild = self.guilds.get(event.state.guild_id) if event.state.session_id in self.voice_states:
if not guild: # Moving channels
return
if event.state.session_id in guild.voice_states:
if event.state.channel_id: if event.state.channel_id:
guild.voice_states[event.state.session_id].update(event.state) self.voice_states[event.state.session_id].update(event.state)
# Disconnection
else: else:
del guild.voice_states[event.state.session_id] if event.state.guild_id in self.guilds:
if event.state.session_id in self.guilds[event.state.guild_id].voice_states:
del self.guilds[event.state.guild_id].voice_states[event.state.session_id]
del self.voice_states[event.state.session_id]
# New connection
elif event.state.channel_id: elif event.state.channel_id:
guild.voice_states[event.state.session_id] = event.state if event.state.guild_id in self.guilds:
self.guilds[event.state.guild_id].voice_states[event.state.session_id] = event.state
self.voice_states[event.state.session_id] = event.state
def on_guild_member_add(self, event): def on_guild_member_add(self, event):
if event.member.user.id not in self.users: if event.member.user.id not in self.users:
@ -243,6 +260,9 @@ class State(object):
if event.member.guild_id not in self.guilds: if event.member.guild_id not in self.guilds:
return return
if event.member.id not in self.guilds[event.member.guild_id].members:
return
self.guilds[event.member.guild_id].members[event.member.id].update(event.member) self.guilds[event.member.guild_id].members[event.member.id].update(event.member)
def on_guild_member_remove(self, event): def on_guild_member_remove(self, event):
@ -285,6 +305,22 @@ class State(object):
del self.guilds[event.guild_id].roles[event.role_id] del self.guilds[event.guild_id].roles[event.role_id]
def on_guild_emojis_update(self, event):
if event.guild_id not in self.guilds:
return
self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis})
def on_presence_update(self, event): def on_presence_update(self, event):
if event.user.id in self.users: if event.user.id in self.users:
self.users[event.user.id].update(event.presence.user)
self.users[event.user.id].presence = event.presence self.users[event.user.id].presence = event.presence
event.presence.user = self.users[event.user.id]
if event.guild_id not in self.guilds:
return
if event.user.id not in self.guilds[event.guild_id].members:
return
self.guilds[event.guild_id].members[event.user.id].user.update(event.user)

1
disco/types/__init__.py

@ -1,3 +1,4 @@
from disco.types.base import UNSET
from disco.types.channel import Channel from disco.types.channel import Channel
from disco.types.guild import Guild, GuildMember, Role from disco.types.guild import Guild, GuildMember, Role
from disco.types.user import User from disco.types.user import User

272
disco/types/base.py

@ -3,7 +3,7 @@ import gevent
import inspect import inspect
import functools import functools
from holster.enum import BaseEnumMeta from holster.enum import BaseEnumMeta, EnumAttr
from datetime import datetime as real_datetime from datetime import datetime as real_datetime
from disco.util.functional import CachedSlotProperty from disco.util.functional import CachedSlotProperty
@ -15,45 +15,61 @@ DATETIME_FORMATS = [
] ]
def get_item_by_path(obj, path):
for part in path.split('.'):
obj = getattr(obj, part)
return obj
class Unset(object):
def __nonzero__(self):
return False
UNSET = Unset()
class ConversionError(Exception): class ConversionError(Exception):
def __init__(self, field, raw, e): def __init__(self, field, raw, e):
super(ConversionError, self).__init__( super(ConversionError, self).__init__(
'Failed to convert `{}` (`{}`) to {}: {}'.format( 'Failed to convert `{}` (`{}`) to {}: {}'.format(
str(raw)[:144], field.src_name, field.typ, e)) str(raw)[:144], field.src_name, field.true_type, e))
if six.PY3:
self.__cause__ = e
class FieldType(object):
def __init__(self, typ):
if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model):
self.typ = typ
elif isinstance(typ, BaseEnumMeta):
self.typ = lambda raw, _: typ.get(raw)
elif typ is None:
self.typ = lambda x, y: None
else:
self.typ = lambda raw, _: typ(raw)
def try_convert(self, raw, client): class Field(object):
pass def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, cast=None, **kwargs):
# TODO: fix default bullshit
def __call__(self, raw, client): self.true_type = value_type
return self.try_convert(raw, client) self.src_name = alias
self.dst_name = None
self.ignore_dump = ignore_dump or []
self.cast = cast
self.metadata = kwargs
if default is not None:
self.default = default
elif not hasattr(self, 'default'):
self.default = None
class Field(FieldType): self.deserializer = None
def __init__(self, typ, alias=None, default=None):
super(Field, self).__init__(typ)
# Set names if value_type:
self.src_name = alias self.deserializer = self.type_to_deserializer(value_type)
self.dst_name = None
self.default = default if isinstance(self.deserializer, Field) and self.default is None:
self.default = self.deserializer.default
elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None and create:
self.default = self.deserializer
if isinstance(self.typ, FieldType): @property
self.default = self.typ.default def name(self):
return None
def set_name(self, name): @name.setter
def name(self, name):
if not self.dst_name: if not self.dst_name:
self.dst_name = name self.dst_name = name
@ -65,31 +81,82 @@ class Field(FieldType):
def try_convert(self, raw, client): def try_convert(self, raw, client):
try: try:
return self.typ(raw, client) return self.deserializer(raw, client)
except Exception as e: except Exception as e:
six.raise_from(ConversionError(self, raw, e), e) six.reraise(ConversionError, ConversionError(self, raw, e))
@staticmethod
def type_to_deserializer(typ):
if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
return typ
elif isinstance(typ, BaseEnumMeta):
return lambda raw, _: typ.get(raw)
elif typ is None:
return lambda x, y: None
else:
return lambda raw, _: typ(raw)
@staticmethod
def serialize(value, inst=None):
if isinstance(value, EnumAttr):
return value.value
elif isinstance(value, Model):
return value.to_dict(ignore=(inst.ignore_dump if inst else []))
else:
if inst and inst.cast:
return inst.cast(value)
return value
class _Dict(FieldType): def __call__(self, raw, client):
return self.try_convert(raw, client)
class DictField(Field):
default = HashMap default = HashMap
def __init__(self, typ, key=None): def __init__(self, key_type, value_type=None, **kwargs):
super(_Dict, self).__init__(typ) super(DictField, self).__init__({}, **kwargs)
self.key = key self.true_key_type = key_type
self.true_value_type = value_type
self.key_de = self.type_to_deserializer(key_type)
self.value_de = self.type_to_deserializer(value_type or key_type)
@staticmethod
def serialize(value, inst=None):
return {
Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value)
if k not in (inst.ignore_dump if inst else [])
}
def try_convert(self, raw, client): def try_convert(self, raw, client):
if self.key: return HashMap({
converted = [self.typ(i, client) for i in raw] self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
return HashMap({getattr(i, self.key): i for i in converted}) })
else:
return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)})
class _List(FieldType): class ListField(Field):
default = list default = list
@staticmethod
def serialize(value, inst=None):
return list(map(Field.serialize, value))
def try_convert(self, raw, client):
return [self.deserializer(i, client) for i in raw]
class AutoDictField(Field):
default = HashMap
def __init__(self, value_type, key, **kwargs):
super(AutoDictField, self).__init__({}, **kwargs)
self.value_de = self.type_to_deserializer(value_type)
self.key = key
def try_convert(self, raw, client): def try_convert(self, raw, client):
return [self.typ(i, client) for i in raw] return HashMap({
getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw)
})
def _make(typ, data, client): def _make(typ, data, client):
@ -104,37 +171,19 @@ def snowflake(data):
def enum(typ): def enum(typ):
def _f(data): def _f(data):
if isinstance(data, str):
data = data.lower()
return typ.get(data) if data is not None else None return typ.get(data) if data is not None else None
return _f return _f
def listof(*args, **kwargs):
return _List(*args, **kwargs)
def dictof(*args, **kwargs):
return _Dict(*args, **kwargs)
def lazy_datetime(data):
if not data:
return property(lambda: None)
def get():
for fmt in DATETIME_FORMATS:
try:
return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
except (ValueError, TypeError):
continue
raise ValueError('Failed to conver `{}` to datetime'.format(data))
return property(get)
def datetime(data): def datetime(data):
if not data: if not data:
return None return None
if isinstance(data, int):
return real_datetime.utcfromtimestamp(data)
for fmt in DATETIME_FORMATS: for fmt in DATETIME_FORMATS:
try: try:
return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
@ -145,6 +194,9 @@ def datetime(data):
def text(obj): def text(obj):
if obj is None:
return None
if six.PY2: if six.PY2:
if isinstance(obj, str): if isinstance(obj, str):
return obj.decode('utf-8') return obj.decode('utf-8')
@ -154,6 +206,9 @@ def text(obj):
def binary(obj): def binary(obj):
if obj is None:
return None
if six.PY2: if six.PY2:
if isinstance(obj, str): if isinstance(obj, str):
return obj.decode('utf-8') return obj.decode('utf-8')
@ -165,13 +220,16 @@ def binary(obj):
def with_equality(field): def with_equality(field):
class T(object): class T(object):
def __eq__(self, other): def __eq__(self, other):
return getattr(self, field) == getattr(other, field) if isinstance(other, self.__class__):
return getattr(self, field) == getattr(other, field)
else:
return getattr(self, field) == other
return T return T
def with_hash(field): def with_hash(field):
class T(object): class T(object):
def __hash__(self, other): def __hash__(self):
return hash(getattr(self, field)) return hash(getattr(self, field))
return T return T
@ -182,7 +240,7 @@ SlottedModel = None
class ModelMeta(type): class ModelMeta(type):
def __new__(cls, name, parents, dct): def __new__(mcs, name, parents, dct):
fields = {} fields = {}
for parent in parents: for parent in parents:
@ -193,7 +251,7 @@ class ModelMeta(type):
if not isinstance(v, Field): if not isinstance(v, Field):
continue continue
v.set_name(k) v.name = k
fields[k] = v fields[k] = v
if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)): if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)):
@ -209,7 +267,7 @@ class ModelMeta(type):
dct = {k: v for k, v in six.iteritems(dct) if k not in fields} dct = {k: v for k, v in six.iteritems(dct) if k not in fields}
dct['_fields'] = fields dct['_fields'] = fields
return super(ModelMeta, cls).__new__(cls, name, parents, dct) return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
class AsyncChainable(object): class AsyncChainable(object):
@ -233,23 +291,49 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
else: else:
obj = kwargs obj = kwargs
for name, field in six.iteritems(self.__class__._fields): self.load(obj)
if field.src_name not in obj or obj[field.src_name] is None: self.validate()
if field.has_default():
default = field.default() if callable(field.default) else field.default def validate(self):
else: pass
default = None
setattr(self, field.dst_name, default) @property
def _fields(self):
return self.__class__._fields
def load(self, obj, consume=False, skip=None):
return self.load_into(self, obj, consume, skip)
def load_into(self, inst, obj, consume=False, skip=None):
for name, field in six.iteritems(self._fields):
should_skip = skip and name in skip
if consume and not should_skip:
raw = obj.pop(field.src_name, UNSET)
else:
raw = obj.get(field.src_name, UNSET)
# If the field is unset/none, and we have a default we need to set it
if (raw in (None, UNSET) or should_skip) and field.has_default():
default = field.default() if callable(field.default) else field.default
setattr(inst, field.dst_name, default)
continue
# Otherwise if the field is UNSET and has no default, skip conversion
if raw is UNSET or should_skip:
setattr(inst, field.dst_name, raw)
continue continue
value = field.try_convert(obj[field.src_name], self.client) value = field.try_convert(raw, self.client)
setattr(self, field.dst_name, value) setattr(inst, field.dst_name, value)
def update(self, other): def update(self, other, ignored=None):
for name in six.iterkeys(self.__class__._fields): for name in six.iterkeys(self._fields):
value = getattr(other, name) if ignored and name in ignored:
if value: continue
setattr(self, name, value)
if hasattr(other, name) and not getattr(other, name) is UNSET:
setattr(self, name, getattr(other, name))
# Clear cached properties # Clear cached properties
for name in dir(type(self)): for name in dir(type(self)):
@ -259,8 +343,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
except: except:
pass pass
def to_dict(self): def to_dict(self, ignore=None):
return {k: getattr(self, k) for k in six.iterkeys(self.__class__._fields)} obj = {}
for name, field in six.iteritems(self.__class__._fields):
if ignore and name in ignore:
continue
if getattr(self, name) == UNSET:
continue
obj[name] = field.serialize(getattr(self, name), field)
return obj
@classmethod @classmethod
def create(cls, client, data, **kwargs): def create(cls, client, data, **kwargs):
@ -269,8 +361,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
return inst return inst
@classmethod @classmethod
def create_map(cls, client, data): def create_map(cls, client, data, **kwargs):
return list(map(functools.partial(cls.create, client), data)) return list(map(functools.partial(cls.create, client, **kwargs), data))
@classmethod
def create_hash(cls, client, key, data, **kwargs):
return HashMap({
get_item_by_path(item, key): item
for item in [
cls.create(client, item, **kwargs) for item in data]
})
@classmethod @classmethod
def attach(cls, it, data): def attach(cls, it, data):

102
disco/types/channel.py

@ -1,11 +1,12 @@
import six import six
from six.moves import map
from holster.enum import Enum from holster.enum import Enum
from disco.util.snowflake import to_snowflake from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property, one_or_many, chunks from disco.util.functional import cached_property, one_or_many, chunks
from disco.types.user import User from disco.types.user import User
from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text from disco.types.base import SlottedModel, Field, AutoDictField, snowflake, enum, text
from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.types.permissions import Permissions, Permissible, PermissionValue
from disco.voice.client import VoiceClient from disco.voice.client import VoiceClient
@ -33,7 +34,7 @@ class ChannelSubType(SlottedModel):
class PermissionOverwrite(ChannelSubType): class PermissionOverwrite(ChannelSubType):
""" """
A PermissionOverwrite for a :class:`Channel` A PermissionOverwrite for a :class:`Channel`.
Attributes Attributes
---------- ----------
@ -48,8 +49,8 @@ class PermissionOverwrite(ChannelSubType):
""" """
id = Field(snowflake) id = Field(snowflake)
type = Field(enum(PermissionOverwriteType)) type = Field(enum(PermissionOverwriteType))
allow = Field(PermissionValue) allow = Field(PermissionValue, cast=int)
deny = Field(PermissionValue) deny = Field(PermissionValue, cast=int)
channel_id = Field(snowflake) channel_id = Field(snowflake)
@ -57,22 +58,29 @@ class PermissionOverwrite(ChannelSubType):
def create(cls, channel, entity, allow=0, deny=0): def create(cls, channel, entity, allow=0, deny=0):
from disco.types.guild import Role from disco.types.guild import Role
type = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER ptype = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER
return cls( return cls(
client=channel.client, client=channel.client,
id=entity.id, id=entity.id,
type=type, type=ptype,
allow=allow, allow=allow,
deny=deny, deny=deny,
channel_id=channel.id channel_id=channel.id
).save() ).save()
@property
def compiled(self):
value = PermissionValue()
value -= self.deny
value += self.allow
return value
def save(self): def save(self):
self.client.api.channels_permissions_modify(self.channel_id, self.client.api.channels_permissions_modify(self.channel_id,
self.id, self.id,
self.allow.value or 0, self.allow.value or 0,
self.deny.value or 0, self.deny.value or 0,
self.type.name) self.type.name)
return self return self
def delete(self): def delete(self):
@ -81,7 +89,7 @@ class PermissionOverwrite(ChannelSubType):
class Channel(SlottedModel, Permissible): class Channel(SlottedModel, Permissible):
""" """
Represents a Discord Channel Represents a Discord Channel.
Attributes Attributes
---------- ----------
@ -111,18 +119,27 @@ class Channel(SlottedModel, Permissible):
last_message_id = Field(snowflake) last_message_id = Field(snowflake)
position = Field(int) position = Field(int)
bitrate = Field(int) bitrate = Field(int)
recipients = Field(listof(User)) recipients = AutoDictField(User, 'id')
type = Field(enum(ChannelType)) type = Field(enum(ChannelType))
overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites')
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Channel, self).__init__(*args, **kwargs) super(Channel, self).__init__(*args, **kwargs)
self.after_load()
def after_load(self):
# TODO: hackfix
self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self})
def __str__(self):
return u'#{}'.format(self.name)
def __repr__(self):
return u'<Channel {} ({})>'.format(self.id, self)
def get_permissions(self, user): def get_permissions(self, user):
""" """
Get the permissions a user has in the channel Get the permissions a user has in the channel.
Returns Returns
------- -------
@ -132,8 +149,8 @@ class Channel(SlottedModel, Permissible):
if not self.guild_id: if not self.guild_id:
return Permissions.ADMINISTRATOR return Permissions.ADMINISTRATOR
member = self.guild.members.get(user.id) member = self.guild.get_member(user)
base = self.guild.get_permissions(user) base = self.guild.get_permissions(member)
for ow in six.itervalues(self.overwrites): for ow in six.itervalues(self.overwrites):
if ow.id != user.id and ow.id not in member.roles: if ow.id != user.id and ow.id not in member.roles:
@ -144,48 +161,55 @@ class Channel(SlottedModel, Permissible):
return base return base
@property
def mention(self):
return '<#{}>'.format(self.id)
@property @property
def is_guild(self): def is_guild(self):
""" """
Whether this channel belongs to a guild Whether this channel belongs to a guild.
""" """
return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE) return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE)
@property @property
def is_dm(self): def is_dm(self):
""" """
Whether this channel is a DM (does not belong to a guild) Whether this channel is a DM (does not belong to a guild).
""" """
return self.type in (ChannelType.DM, ChannelType.GROUP_DM) return self.type in (ChannelType.DM, ChannelType.GROUP_DM)
@property @property
def is_voice(self): def is_voice(self):
""" """
Whether this channel supports voice Whether this channel supports voice.
""" """
return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM) return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM)
@property @property
def messages(self): def messages(self):
""" """
a default :class:`MessageIterator` for the channel a default :class:`MessageIterator` for the channel.
""" """
return self.messages_iter() return self.messages_iter()
@cached_property @cached_property
def guild(self): def guild(self):
""" """
Guild this channel belongs to (if relevant) Guild this channel belongs to (if relevant).
""" """
return self.client.state.guilds.get(self.guild_id) return self.client.state.guilds.get(self.guild_id)
def messages_iter(self, **kwargs): def messages_iter(self, **kwargs):
""" """
Creates a new :class:`MessageIterator` for the channel with the given Creates a new :class:`MessageIterator` for the channel with the given
keyword arguments keyword arguments.
""" """
return MessageIterator(self.client, self, **kwargs) return MessageIterator(self.client, self, **kwargs)
def get_message(self, message):
return self.client.api.channels_messages_get(self.id, to_snowflake(message))
def get_invites(self): def get_invites(self):
""" """
Returns Returns
@ -220,9 +244,9 @@ class Channel(SlottedModel, Permissible):
def create_webhook(self, name=None, avatar=None): def create_webhook(self, name=None, avatar=None):
return self.client.api.channels_webhooks_create(self.id, name, avatar) return self.client.api.channels_webhooks_create(self.id, name, avatar)
def send_message(self, content, nonce=None, tts=False): def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None):
""" """
Send a message in this channel Send a message in this channel.
Parameters Parameters
---------- ----------
@ -238,11 +262,11 @@ class Channel(SlottedModel, Permissible):
:class:`disco.types.message.Message` :class:`disco.types.message.Message`
The created message. The created message.
""" """
return self.client.api.channels_messages_create(self.id, content, nonce, tts) return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed)
def connect(self, *args, **kwargs): def connect(self, *args, **kwargs):
""" """
Connect to this channel over voice Connect to this channel over voice.
""" """
assert self.is_voice, 'Channel must support voice to connect' assert self.is_voice, 'Channel must support voice to connect'
vc = VoiceClient(self) vc = VoiceClient(self)
@ -275,17 +299,29 @@ class Channel(SlottedModel, Permissible):
List of messages (or message ids) to delete. All messages must originate List of messages (or message ids) to delete. All messages must originate
from this channel. from this channel.
""" """
messages = map(to_snowflake, messages) message_ids = list(map(to_snowflake, messages))
if not messages: if not message_ids:
return return
if len(messages) <= 2: if self.can(self.client.state.me, Permissions.MANAGE_MESSAGES) and len(messages) > 2:
for chunk in chunks(message_ids, 100):
self.client.api.channels_messages_delete_bulk(self.id, chunk)
else:
for msg in messages: for msg in messages:
self.delete_message(msg) self.delete_message(msg)
else:
for chunk in chunks(messages, 100): def delete(self):
self.client.api.channels_messages_delete_bulk(self.id, chunk) assert (self.is_dm or self.guild.can(self.client.state.me, Permissions.MANAGE_GUILD)), 'Invalid Permissions'
self.client.api.channels_delete(self.id)
def close(self):
"""
Closes a DM channel. This is intended as a safer version of `delete`,
enforcing that the channel is actually a DM.
"""
assert self.is_dm, 'Cannot close non-DM channel'
self.delete()
class MessageIterator(object): class MessageIterator(object):
@ -329,7 +365,7 @@ class MessageIterator(object):
def fill(self): def fill(self):
""" """
Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API.
""" """
self._buffer = self.client.api.channels_messages_list( self._buffer = self.client.api.channels_messages_list(
self.channel.id, self.channel.id,

129
disco/types/guild.py

@ -6,10 +6,13 @@ from disco.gateway.packets import OPCode
from disco.api.http import APIException from disco.api.http import APIException
from disco.util.snowflake import to_snowflake from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum from disco.types.base import (
from disco.types.user import User SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum, datetime
)
from disco.types.user import User, Presence
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.channel import Channel from disco.types.channel import Channel
from disco.types.message import Emoji
from disco.types.permissions import PermissionValue, Permissions, Permissible from disco.types.permissions import PermissionValue, Permissions, Permissible
@ -18,21 +21,12 @@ VerificationLevel = Enum(
LOW=1, LOW=1,
MEDIUM=2, MEDIUM=2,
HIGH=3, HIGH=3,
EXTREME=4,
) )
class GuildSubType(SlottedModel): class GuildEmoji(Emoji):
guild_id = Field(None)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class Emoji(GuildSubType):
""" """
An emoji object An emoji object.
Attributes Attributes
---------- ----------
@ -48,15 +42,27 @@ class Emoji(GuildSubType):
Roles this emoji is attached to. Roles this emoji is attached to.
""" """
id = Field(snowflake) id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text) name = Field(text)
require_colons = Field(bool) require_colons = Field(bool)
managed = Field(bool) managed = Field(bool)
roles = Field(listof(snowflake)) roles = ListField(snowflake)
def __str__(self):
return u'<:{}:{}>'.format(self.name, self.id)
class Role(GuildSubType): @property
def url(self):
return 'https://discordapp.com/api/emojis/{}.png'.format(self.id)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class Role(SlottedModel):
""" """
A role object A role object.
Attributes Attributes
---------- ----------
@ -76,6 +82,7 @@ class Role(GuildSubType):
The position of this role in the hierarchy. The position of this role in the hierarchy.
""" """
id = Field(snowflake) id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text) name = Field(text)
hoist = Field(bool) hoist = Field(bool)
managed = Field(bool) managed = Field(bool)
@ -84,6 +91,9 @@ class Role(GuildSubType):
position = Field(int) position = Field(int)
mentionable = Field(bool) mentionable = Field(bool)
def __str__(self):
return self.name
def delete(self): def delete(self):
self.guild.delete_role(self) self.guild.delete_role(self)
@ -94,10 +104,19 @@ class Role(GuildSubType):
def mention(self): def mention(self):
return '<@{}>'.format(self.id) return '<@{}>'.format(self.id)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class GuildMember(GuildSubType): class GuildBan(SlottedModel):
user = Field(User)
reason = Field(str)
class GuildMember(SlottedModel):
""" """
A GuildMember object A GuildMember object.
Attributes Attributes
---------- ----------
@ -121,8 +140,18 @@ class GuildMember(GuildSubType):
nick = Field(text) nick = Field(text)
mute = Field(bool) mute = Field(bool)
deaf = Field(bool) deaf = Field(bool)
joined_at = Field(str) joined_at = Field(datetime)
roles = Field(listof(snowflake)) roles = ListField(snowflake)
def __str__(self):
return self.user.__str__()
@property
def name(self):
"""
The nickname of this user if set, otherwise their username
"""
return self.nick or self.user.username
def get_voice_state(self): def get_voice_state(self):
""" """
@ -151,6 +180,12 @@ class GuildMember(GuildSubType):
""" """
self.guild.create_ban(self, delete_message_days) self.guild.create_ban(self, delete_message_days)
def unban(self):
"""
Unbans the member from the guild.
"""
self.guild.delete_ban(self)
def set_nickname(self, nickname=None): def set_nickname(self, nickname=None):
""" """
Sets the member's nickname (or clears it if None). Sets the member's nickname (or clears it if None).
@ -160,11 +195,19 @@ class GuildMember(GuildSubType):
nickname : Optional[str] nickname : Optional[str]
The nickname (or none to reset) to set. The nickname (or none to reset) to set.
""" """
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') if self.client.state.me.id == self.user.id:
self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '')
else:
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '')
def modify(self, **kwargs):
self.client.api.guilds_members_modify(self.guild.id, self.user.id, **kwargs)
def add_role(self, role): def add_role(self, role):
roles = self.roles + [role.id] self.client.api.guilds_members_roles_add(self.guild.id, self.user.id, to_snowflake(role))
self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles)
def remove_role(self, role):
self.client.api.guilds_members_roles_remove(self.guild.id, self.user.id, to_snowflake(role))
@cached_property @cached_property
def owner(self): def owner(self):
@ -179,14 +222,22 @@ class GuildMember(GuildSubType):
@property @property
def id(self): def id(self):
""" """
Alias to the guild members user id Alias to the guild members user id.
""" """
return self.user.id return self.user.id
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@cached_property
def permissions(self):
return self.guild.get_permissions(self)
class Guild(SlottedModel, Permissible): class Guild(SlottedModel, Permissible):
""" """
A guild object A guild object.
Attributes Attributes
---------- ----------
@ -222,7 +273,7 @@ class Guild(SlottedModel, Permissible):
All of the guild's channels. All of the guild's channels.
roles : dict(snowflake, :class:`Role`) roles : dict(snowflake, :class:`Role`)
All of the guild's roles. All of the guild's roles.
emojis : dict(snowflake, :class:`Emoji`) emojis : dict(snowflake, :class:`GuildEmoji`)
All of the guild's emojis. All of the guild's emojis.
voice_states : dict(str, :class:`disco.types.voice.VoiceState`) voice_states : dict(str, :class:`disco.types.voice.VoiceState`)
All of the guild's voice states. All of the guild's voice states.
@ -239,12 +290,14 @@ class Guild(SlottedModel, Permissible):
embed_enabled = Field(bool) embed_enabled = Field(bool)
verification_level = Field(enum(VerificationLevel)) verification_level = Field(enum(VerificationLevel))
mfa_level = Field(int) mfa_level = Field(int)
features = Field(listof(str)) features = ListField(str)
members = Field(dictof(GuildMember, key='id')) members = AutoDictField(GuildMember, 'id')
channels = Field(dictof(Channel, key='id')) channels = AutoDictField(Channel, 'id')
roles = Field(dictof(Role, key='id')) roles = AutoDictField(Role, 'id')
emojis = Field(dictof(Emoji, key='id')) emojis = AutoDictField(GuildEmoji, 'id')
voice_states = Field(dictof(VoiceState, key='session_id')) voice_states = AutoDictField(VoiceState, 'session_id')
member_count = Field(int)
presences = ListField(Presence)
synced = Field(bool, default=False) synced = Field(bool, default=False)
@ -257,7 +310,7 @@ class Guild(SlottedModel, Permissible):
self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) self.attach(six.itervalues(self.emojis), {'guild_id': self.id})
self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) self.attach(six.itervalues(self.voice_states), {'guild_id': self.id})
def get_permissions(self, user): def get_permissions(self, member):
""" """
Get the permissions a user has in this guild. Get the permissions a user has in this guild.
@ -266,10 +319,13 @@ class Guild(SlottedModel, Permissible):
:class:`disco.types.permissions.PermissionValue` :class:`disco.types.permissions.PermissionValue`
Computed permission value for the user. Computed permission value for the user.
""" """
if self.owner_id == user.id: if not isinstance(member, GuildMember):
member = self.get_member(member)
# Owner has all permissions
if self.owner_id == member.id:
return PermissionValue(Permissions.ADMINISTRATOR) return PermissionValue(Permissions.ADMINISTRATOR)
member = self.get_member(user)
value = PermissionValue(self.roles.get(self.id).permissions) value = PermissionValue(self.roles.get(self.id).permissions)
for role in map(self.roles.get, member.roles): for role in map(self.roles.get, member.roles):
@ -358,3 +414,6 @@ class Guild(SlottedModel, Permissible):
def create_ban(self, user, delete_message_days=0): def create_ban(self, user, delete_message_days=0):
self.client.api.guilds_bans_create(self.id, to_snowflake(user), delete_message_days) self.client.api.guilds_bans_create(self.id, to_snowflake(user), delete_message_days)
def create_channel(self, *args, **kwargs):
return self.client.api.guilds_channels_create(self.id, *args, **kwargs)

6
disco/types/invite.py

@ -1,4 +1,4 @@
from disco.types.base import SlottedModel, Field, lazy_datetime from disco.types.base import SlottedModel, Field, datetime
from disco.types.user import User from disco.types.user import User
from disco.types.guild import Guild from disco.types.guild import Guild
from disco.types.channel import Channel from disco.types.channel import Channel
@ -6,7 +6,7 @@ from disco.types.channel import Channel
class Invite(SlottedModel): class Invite(SlottedModel):
""" """
An invite object An invite object.
Attributes Attributes
---------- ----------
@ -37,7 +37,7 @@ class Invite(SlottedModel):
max_uses = Field(int) max_uses = Field(int)
uses = Field(int) uses = Field(int)
temporary = Field(bool) temporary = Field(bool)
created_at = Field(lazy_datetime) created_at = Field(datetime)
@classmethod @classmethod
def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False): def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False):

200
disco/types/message.py

@ -1,8 +1,14 @@
import re import re
import six
import functools
import unicodedata
from holster.enum import Enum from holster.enum import Enum
from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text,
datetime, enum
)
from disco.util.snowflake import to_snowflake from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.user import User from disco.types.user import User
@ -19,6 +25,31 @@ MessageType = Enum(
) )
class Emoji(SlottedModel):
id = Field(snowflake)
name = Field(text)
def __eq__(self, other):
if isinstance(other, Emoji):
return self.id == other.id and self.name == other.name
raise NotImplementedError
def to_string(self):
if self.id:
return '{}:{}'.format(self.name, self.id)
return self.name
class MessageReactionEmoji(Emoji):
pass
class MessageReaction(SlottedModel):
emoji = Field(MessageReactionEmoji)
count = Field(int)
me = Field(bool)
class MessageEmbedFooter(SlottedModel): class MessageEmbedFooter(SlottedModel):
text = Field(text) text = Field(text)
icon_url = Field(text) icon_url = Field(text)
@ -60,7 +91,7 @@ class MessageEmbedField(SlottedModel):
class MessageEmbed(SlottedModel): class MessageEmbed(SlottedModel):
""" """
Message embed object Message embed object.
Attributes Attributes
---------- ----------
@ -76,20 +107,38 @@ class MessageEmbed(SlottedModel):
title = Field(text) title = Field(text)
type = Field(str, default='rich') type = Field(str, default='rich')
description = Field(text) description = Field(text)
url = Field(str) url = Field(text)
timestamp = Field(lazy_datetime) timestamp = Field(datetime)
color = Field(int) color = Field(int)
footer = Field(MessageEmbedFooter) footer = Field(MessageEmbedFooter)
image = Field(MessageEmbedImage) image = Field(MessageEmbedImage)
thumbnail = Field(MessageEmbedThumbnail) thumbnail = Field(MessageEmbedThumbnail)
video = Field(MessageEmbedVideo) video = Field(MessageEmbedVideo)
author = Field(MessageEmbedAuthor) author = Field(MessageEmbedAuthor)
fields = Field(listof(MessageEmbedField)) fields = ListField(MessageEmbedField)
def set_footer(self, *args, **kwargs):
self.footer = MessageEmbedFooter(*args, **kwargs)
def set_image(self, *args, **kwargs):
self.image = MessageEmbedImage(*args, **kwargs)
def set_thumbnail(self, *args, **kwargs):
self.thumbnail = MessageEmbedThumbnail(*args, **kwargs)
def set_video(self, *args, **kwargs):
self.video = MessageEmbedVideo(*args, **kwargs)
def set_author(self, *args, **kwargs):
self.author = MessageEmbedAuthor(*args, **kwargs)
def add_field(self, *args, **kwargs):
self.fields.append(MessageEmbedField(*args, **kwargs))
class MessageAttachment(SlottedModel): class MessageAttachment(SlottedModel):
""" """
Message attachment object Message attachment object.
Attributes Attributes
---------- ----------
@ -110,8 +159,8 @@ class MessageAttachment(SlottedModel):
""" """
id = Field(str) id = Field(str)
filename = Field(text) filename = Field(text)
url = Field(str) url = Field(text)
proxy_url = Field(str) proxy_url = Field(text)
size = Field(int) size = Field(int)
height = Field(int) height = Field(int)
width = Field(int) width = Field(int)
@ -161,15 +210,16 @@ class Message(SlottedModel):
author = Field(User) author = Field(User)
content = Field(text) content = Field(text)
nonce = Field(snowflake) nonce = Field(snowflake)
timestamp = Field(lazy_datetime) timestamp = Field(datetime)
edited_timestamp = Field(lazy_datetime) edited_timestamp = Field(datetime)
tts = Field(bool) tts = Field(bool)
mention_everyone = Field(bool) mention_everyone = Field(bool)
pinned = Field(bool) pinned = Field(bool)
mentions = Field(dictof(User, key='id')) mentions = AutoDictField(User, 'id')
mention_roles = Field(listof(snowflake)) mention_roles = ListField(snowflake)
embeds = Field(listof(MessageEmbed)) embeds = ListField(MessageEmbed)
attachments = Field(dictof(MessageAttachment, key='id')) attachments = AutoDictField(MessageAttachment, 'id')
reactions = ListField(MessageReaction)
def __str__(self): def __str__(self):
return '<Message {} ({})>'.format(self.id, self.channel_id) return '<Message {} ({})>'.format(self.id, self.channel_id)
@ -213,7 +263,7 @@ class Message(SlottedModel):
def reply(self, *args, **kwargs): def reply(self, *args, **kwargs):
""" """
Reply to this message (proxys arguments to Reply to this message (proxys arguments to
:func:`disco.types.channel.Channel.send_message`) :func:`disco.types.channel.Channel.send_message`).
Returns Returns
------- -------
@ -222,9 +272,9 @@ class Message(SlottedModel):
""" """
return self.channel.send_message(*args, **kwargs) return self.channel.send_message(*args, **kwargs)
def edit(self, content): def edit(self, *args, **kwargs):
""" """
Edit this message Edit this message.
Args Args
---- ----
@ -236,7 +286,7 @@ class Message(SlottedModel):
:class:`Message` :class:`Message`
The edited message object. The edited message object.
""" """
return self.client.api.channels_messages_modify(self.channel_id, self.id, content) return self.client.api.channels_messages_modify(self.channel_id, self.id, *args, **kwargs)
def delete(self): def delete(self):
""" """
@ -249,6 +299,42 @@ class Message(SlottedModel):
""" """
return self.client.api.channels_messages_delete(self.channel_id, self.id) return self.client.api.channels_messages_delete(self.channel_id, self.id)
def get_reactors(self, emoji):
"""
Returns an list of users who reacted to this message with the given emoji.
Returns
-------
list(:class:`User`)
The users who reacted.
"""
return self.client.api.channels_messages_reactions_get(
self.channel_id,
self.id,
emoji
)
def create_reaction(self, emoji):
if isinstance(emoji, Emoji):
emoji = emoji.to_string()
self.client.api.channels_messages_reactions_create(
self.channel_id,
self.id,
emoji)
def delete_reaction(self, emoji, user=None):
if isinstance(emoji, Emoji):
emoji = emoji.to_string()
if user:
user = to_snowflake(user)
self.client.api.channels_messages_reactions_delete(
self.channel_id,
self.id,
emoji,
user)
def is_mentioned(self, entity): def is_mentioned(self, entity):
""" """
Returns Returns
@ -256,22 +342,37 @@ class Message(SlottedModel):
bool bool
Whether the give entity was mentioned. Whether the give entity was mentioned.
""" """
id = to_snowflake(entity) entity = to_snowflake(entity)
return id in self.mentions or id in self.mention_roles return entity in self.mentions or entity in self.mention_roles
@cached_property @cached_property
def without_mentions(self): def without_mentions(self, valid_only=False):
""" """
Returns Returns
------- -------
str str
the message contents with all valid mentions removed. the message contents with all mentions removed.
""" """
return self.replace_mentions( return self.replace_mentions(
lambda u: '', lambda u: '',
lambda r: '') lambda r: '',
lambda c: '',
nonexistant=not valid_only)
@cached_property
def with_proper_mentions(self):
def replace_user(u):
return u'@' + six.text_type(u)
def replace_role(r):
return u'@' + six.text_type(r)
def replace_channel(c):
return six.text_type(c)
def replace_mentions(self, user_replace, role_replace): return self.replace_mentions(replace_user, replace_role, replace_channel)
def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False):
""" """
Replaces user and role mentions with the result of a given lambda/function. Replaces user and role mentions with the result of a given lambda/function.
@ -289,39 +390,55 @@ class Message(SlottedModel):
str str
The message contents with all valid mentions replaced. The message contents with all valid mentions replaced.
""" """
if not self.mentions and not self.mention_roles: def replace(getter, func, match):
return oid = int(match.group(2))
obj = getter(oid)
if obj or nonexistant:
return func(obj or oid) or match.group(0)
return match.group(0)
content = self.content
if user_replace:
replace_user = functools.partial(replace, self.mentions.get, user_replace)
content = re.sub('(<@!?([0-9]+)>)', replace_user, content)
if role_replace:
replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace)
content = re.sub('(<@&([0-9]+)>)', replace_role, content)
def replace(match): if channel_replace:
id = match.group(0) replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace)
if id in self.mention_roles: content = re.sub('(<#([0-9]+)>)', replace_channel, content)
return role_replace(id)
else:
return user_replace(self.mentions.get(id))
return re.sub('<@!?([0-9]+)>', replace, self.content) return content
class MessageTable(object): class MessageTable(object):
def __init__(self, sep=' | ', codeblock=True, header_break=True): def __init__(self, sep=' | ', codeblock=True, header_break=True, language=None):
self.header = [] self.header = []
self.entries = [] self.entries = []
self.size_index = {} self.size_index = {}
self.sep = sep self.sep = sep
self.codeblock = codeblock self.codeblock = codeblock
self.header_break = header_break self.header_break = header_break
self.language = language
def recalculate_size_index(self, cols): def recalculate_size_index(self, cols):
for idx, col in enumerate(cols): for idx, col in enumerate(cols):
if idx not in self.size_index or len(col) > self.size_index[idx]: size = len(unicodedata.normalize('NFC', col))
self.size_index[idx] = len(col) if idx not in self.size_index or size > self.size_index[idx]:
self.size_index[idx] = size
def set_header(self, *args): def set_header(self, *args):
args = list(map(six.text_type, args))
self.header = args self.header = args
self.recalculate_size_index(args) self.recalculate_size_index(args)
def add(self, *args): def add(self, *args):
args = list(map(str, args)) args = list(map(six.text_type, args))
self.entries.append(args) self.entries.append(args)
self.recalculate_size_index(args) self.recalculate_size_index(args)
@ -329,22 +446,23 @@ class MessageTable(object):
data = self.sep.lstrip() data = self.sep.lstrip()
for idx, col in enumerate(cols): for idx, col in enumerate(cols):
padding = ' ' * ((self.size_index[idx] - len(col))) padding = ' ' * (self.size_index[idx] - len(col))
data += col + padding + self.sep data += col + padding + self.sep
return data.rstrip() return data.rstrip()
def compile(self): def compile(self):
data = [] data = []
data.append(self.compile_one(self.header)) if self.header:
data = [self.compile_one(self.header)]
if self.header_break: if self.header and self.header_break:
data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1)) data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1))
for row in self.entries: for row in self.entries:
data.append(self.compile_one(row)) data.append(self.compile_one(row))
if self.codeblock: if self.codeblock:
return '```' + '\n'.join(data) + '```' return '```{}'.format(self.language if self.language else '') + '\n'.join(data) + '```'
return '\n'.join(data) return '\n'.join(data)

12
disco/types/permissions.py

@ -76,13 +76,13 @@ class PermissionValue(object):
return self.sub(other) return self.sub(other)
def __getattribute__(self, name): def __getattribute__(self, name):
if name in Permissions.attrs: if name in Permissions.keys_:
return (self.value & Permissions[name].value) == Permissions[name].value return (self.value & Permissions[name].value) == Permissions[name].value
else: else:
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name not in Permissions.attrs: if name not in Permissions.keys_:
return super(PermissionValue, self).__setattr__(name, value) return super(PermissionValue, self).__setattr__(name, value)
if value: if value:
@ -90,9 +90,12 @@ class PermissionValue(object):
else: else:
self.value &= ~Permissions[name].value self.value &= ~Permissions[name].value
def __int__(self):
return self.value
def to_dict(self): def to_dict(self):
return { return {
k: getattr(self, k) for k in Permissions.attrs k: getattr(self, k) for k in Permissions.keys_
} }
@classmethod @classmethod
@ -107,6 +110,9 @@ class PermissionValue(object):
class Permissible(object): class Permissible(object):
__slots__ = [] __slots__ = []
def get_permissions(self):
raise NotImplementedError
def can(self, user, *args): def can(self, user, *args):
perms = self.get_permissions(user) perms = self.get_permissions(user)
return perms.administrator or perms.can(*args) return perms.administrator or perms.can(*args)

40
disco/types/user.py

@ -2,30 +2,54 @@ from holster.enum import Enum
from disco.types.base import SlottedModel, Field, snowflake, text, binary, with_equality, with_hash from disco.types.base import SlottedModel, Field, snowflake, text, binary, with_equality, with_hash
DefaultAvatars = Enum(
BLURPLE=0,
GREY=1,
GREEN=2,
ORANGE=3,
RED=4,
)
class User(SlottedModel, with_equality('id'), with_hash('id')): class User(SlottedModel, with_equality('id'), with_hash('id')):
id = Field(snowflake) id = Field(snowflake)
username = Field(text) username = Field(text)
avatar = Field(binary) avatar = Field(binary)
discriminator = Field(str) discriminator = Field(str)
bot = Field(bool) bot = Field(bool, default=False)
verified = Field(bool) verified = Field(bool)
email = Field(str) email = Field(str)
presence = Field(None) presence = Field(None)
def get_avatar_url(self, fmt='webp', size=1024):
if not self.avatar:
return 'https://cdn.discordapp.com/embed/avatars/{}.png'.format(self.default_avatar.value)
return 'https://cdn.discordapp.com/avatars/{}/{}.{}?size={}'.format(
self.id,
self.avatar,
fmt,
size
)
@property
def default_avatar(self):
return DefaultAvatars[int(self.discriminator) % len(DefaultAvatars.attrs)]
@property
def avatar_url(self):
return self.get_avatar_url()
@property @property
def mention(self): def mention(self):
return '<@{}>'.format(self.id) return '<@{}>'.format(self.id)
def to_string(self):
return '{}#{}'.format(self.username, self.discriminator)
def __str__(self): def __str__(self):
return '<User {} ({})>'.format(self.id, self.to_string()) return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4))
def on_create(self): def __repr__(self):
self.client.state.users[self.id] = self return u'<User {} ({})>'.format(self.id, self)
GameType = Enum( GameType = Enum(
@ -49,6 +73,6 @@ class Game(SlottedModel):
class Presence(SlottedModel): class Presence(SlottedModel):
user = Field(User) user = Field(User, alias='user', ignore_dump=['presence'])
game = Field(Game) game = Field(Game)
status = Field(Status) status = Field(Status)

2
disco/types/voice.py

@ -17,7 +17,7 @@ class VoiceState(SlottedModel):
def guild(self): def guild(self):
return self.client.state.guilds.get(self.guild_id) return self.client.state.guilds.get(self.guild_id)
@cached_property @property
def channel(self): def channel(self):
return self.client.state.channels.get(self.channel_id) return self.client.state.channels.get(self.channel_id)

6
disco/types/webhook.py

@ -32,12 +32,14 @@ class Webhook(SlottedModel):
else: else:
return self.client.api.webhooks_modify(self.id, name, avatar) return self.client.api.webhooks_modify(self.id, name, avatar)
def execute(self, content=None, username=None, avatar_url=None, tts=False, file=None, embeds=[], wait=False): def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False):
# TODO: support file stuff properly
return self.client.api.webhooks_token_execute(self.id, self.token, { return self.client.api.webhooks_token_execute(self.id, self.token, {
'content': content, 'content': content,
'username': username, 'username': username,
'avatar_url': avatar_url, 'avatar_url': avatar_url,
'tts': tts, 'tts': tts,
'file': file, 'file': fobj,
'embeds': [i.to_dict() for i in embeds], 'embeds': [i.to_dict() for i in embeds],
}, wait) }, wait)

2
disco/util/config.py

@ -29,7 +29,7 @@ class Config(object):
return inst return inst
def from_prefix(self, prefix): def from_prefix(self, prefix):
prefix = prefix + '_' prefix += '_'
obj = {} obj = {}
for k, v in six.iteritems(self.__dict__): for k, v in six.iteritems(self.__dict__):

4
disco/util/hashmap.py

@ -45,12 +45,12 @@ class HashMap(UserDict):
def filter(self, predicate): def filter(self, predicate):
if not callable(predicate): if not callable(predicate):
raise TypeError('predicate must be callable') raise TypeError('predicate must be callable')
return filter(self.values(), predicate) return filter(predicate, self.values())
def map(self, predicate): def map(self, predicate):
if not callable(predicate): if not callable(predicate):
raise TypeError('predicate must be callable') raise TypeError('predicate must be callable')
return map(self.values(), predicate) return map(predicate, self.values())
class DefaultHashMap(defaultdict, HashMap): class DefaultHashMap(defaultdict, HashMap):

3
disco/util/limiter.py

@ -17,7 +17,8 @@ class SimpleLimiter(object):
gevent.sleep(self.reset_at - time.time()) gevent.sleep(self.reset_at - time.time())
self.count = 0 self.count = 0
self.reset_at = 0 self.reset_at = 0
self.event.set() if self.event:
self.event.set()
self.event = None self.event = None
def check(self): def check(self):

35
disco/util/logging.py

@ -3,15 +3,28 @@ from __future__ import absolute_import
import logging import logging
LEVEL_OVERRIDES = {
'requests': logging.WARNING
}
LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'
def setup_logging(**kwargs):
kwargs.setdefault('format', LOG_FORMAT)
logging.basicConfig(**kwargs)
for logger, level in LEVEL_OVERRIDES.items():
logging.getLogger(logger).setLevel(level)
class LoggingClass(object): class LoggingClass(object):
def __init__(self): __slots__ = ['_log']
self.log = logging.getLogger(self.__class__.__name__)
@property
def log_on_error(self, msg, f): def log(self):
def _f(*args, **kwargs): try:
try: return self._log
return f(*args, **kwargs) except AttributeError:
except: self._log = logging.getLogger(self.__class__.__name__)
self.log.exception(msg) return self._log
raise
return _f

36
disco/util/serializer.py

@ -1,3 +1,5 @@
import six
import types
class Serializer(object): class Serializer(object):
@ -36,3 +38,37 @@ class Serializer(object):
def dumps(cls, fmt, raw): def dumps(cls, fmt, raw):
_, dumps = getattr(cls, fmt)() _, dumps = getattr(cls, fmt)()
return dumps(raw) return dumps(raw)
def dump_cell(cell):
return cell.cell_contents
def load_cell(cell):
if six.PY3:
return (lambda y: cell).__closure__[0]
else:
return (lambda y: cell).func_closure[0]
def dump_function(func):
if six.PY3:
return (
func.__code__,
func.__name__,
func.__defaults__,
list(map(dump_cell, func.__closure__)) if func.__closure__ else [],
)
else:
return (
func.func_code,
func.func_name,
func.func_defaults,
list(map(dump_cell, func.func_closure)) if func.func_closure else [],
)
def load_function(args):
code, name, defaults, closure = args
closure = tuple(map(load_cell, closure))
return types.FunctionType(code, globals(), name, defaults, closure)

2
disco/util/snowflake.py

@ -17,7 +17,7 @@ def to_unix(snowflake):
def to_unix_ms(snowflake): def to_unix_ms(snowflake):
return ((int(snowflake) >> 22) + DISCORD_EPOCH) return (int(snowflake) >> 22) + DISCORD_EPOCH
def to_snowflake(i): def to_snowflake(i):

2
disco/util/token.py

@ -5,6 +5,6 @@ TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}')
def is_valid_token(token): def is_valid_token(token):
""" """
Validates a Discord authentication token, returning true if valid Validates a Discord authentication token, returning true if valid.
""" """
return bool(TOKEN_RE.match(token)) return bool(TOKEN_RE.match(token))

7
disco/voice/client.py

@ -106,6 +106,7 @@ class VoiceClient(LoggingClass):
self.endpoint = None self.endpoint = None
self.ssrc = None self.ssrc = None
self.port = None self.port = None
self.udp = None
self.update_listener = None self.update_listener = None
@ -158,7 +159,7 @@ class VoiceClient(LoggingClass):
} }
}) })
def on_voice_sdp(self, data): def on_voice_sdp(self, _):
# Toggle speaking state so clients learn of our SSRC # Toggle speaking state so clients learn of our SSRC
self.set_speaking(True) self.set_speaking(True)
self.set_speaking(False) self.set_speaking(False)
@ -187,11 +188,10 @@ class VoiceClient(LoggingClass):
def on_message(self, msg): def on_message(self, msg):
try: try:
data = self.encoder.decode(msg) data = self.encoder.decode(msg)
self.packets.emit(VoiceOPCode[data['op']], data['d'])
except: except:
self.log.exception('Failed to parse voice gateway message: ') self.log.exception('Failed to parse voice gateway message: ')
self.packets.emit(VoiceOPCode[data['op']], data['d'])
def on_error(self, err): def on_error(self, err):
# TODO # TODO
self.log.warning('Voice websocket error: {}'.format(err)) self.log.warning('Voice websocket error: {}'.format(err))
@ -205,6 +205,7 @@ class VoiceClient(LoggingClass):
}) })
def on_close(self, code, error): def on_close(self, code, error):
# TODO
self.log.warning('Voice websocket disconnected (%s, %s)', code, error) self.log.warning('Voice websocket disconnected (%s, %s)', code, error)
if self.state == VoiceState.CONNECTED: if self.state == VoiceState.CONNECTED:

127
disco/voice/opus.py

@ -1,8 +1,15 @@
import sys import sys
import array import array
import gevent
import ctypes import ctypes
import ctypes.util import ctypes.util
try:
from cStringIO import cStringIO as StringIO
except:
from StringIO import StringIO
from gevent.queue import Queue
from holster.enum import Enum from holster.enum import Enum
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
@ -43,12 +50,12 @@ class BaseOpus(LoggingClass):
for name, item in methods.items(): for name, item in methods.items():
func = getattr(self.lib, name) func = getattr(self.lib, name)
if item[1]: if item[0]:
func.argtypes = item[1] func.argtypes = item[0]
func.restype = item[2] func.restype = item[1]
setattr(self, name.replace('opus_', ''), func) setattr(self, name, func)
@staticmethod @staticmethod
def find_library(): def find_library():
@ -83,7 +90,7 @@ class OpusEncoder(BaseOpus):
} }
def __init__(self, sampling, channels, application=Application.AUDIO, library_path=None): def __init__(self, sampling, channels, application=Application.AUDIO, library_path=None):
super(OpusDecoder, self).__init__(library_path) super(OpusEncoder, self).__init__(library_path)
self.sampling_rate = sampling self.sampling_rate = sampling
self.channels = channels self.channels = channels
self.application = application self.application = application
@ -94,10 +101,32 @@ class OpusEncoder(BaseOpus):
self.frame_size = self.samples_per_frame * self.sample_size self.frame_size = self.samples_per_frame * self.sample_size
self.inst = self.create() self.inst = self.create()
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
def set_bitrate(self, kbps):
kbps = min(128, max(16, int(kbps)))
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_BITRATE), kbps * 1024)
if ret < 0:
raise Exception('Failed to set bitrate to {}: {}'.format(kbps, ret))
def set_fec(self, value):
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_FEC), int(value))
if ret < 0:
raise Exception('Failed to set FEC to {}: {}'.format(value, ret))
def set_expected_packet_loss_percent(self, perc):
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_PLP), min(100, max(0, int(perc * 100))))
if ret < 0:
raise Exception('Failed to set PLP to {}: {}'.format(perc, ret))
def create(self): def create(self):
ret = ctypes.c_int() ret = ctypes.c_int()
result = self.encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret)) result = self.opus_encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret))
if ret.value != 0: if ret.value != 0:
raise Exception('Failed to create opus encoder: {}'.format(ret.value)) raise Exception('Failed to create opus encoder: {}'.format(ret.value))
@ -106,7 +135,7 @@ class OpusEncoder(BaseOpus):
def __del__(self): def __del__(self):
if self.inst: if self.inst:
self.encoder_destroy(self.inst) self.opus_encoder_destroy(self.inst)
self.inst = None self.inst = None
def encode(self, pcm, frame_size): def encode(self, pcm, frame_size):
@ -114,12 +143,92 @@ class OpusEncoder(BaseOpus):
pcm = ctypes.cast(pcm, c_int16_ptr) pcm = ctypes.cast(pcm, c_int16_ptr)
data = (ctypes.c_char * max_data_bytes)() data = (ctypes.c_char * max_data_bytes)()
ret = self.encode(self.inst, pcm, frame_size, data, max_data_bytes) ret = self.opus_encode(self.inst, pcm, frame_size, data, max_data_bytes)
if ret < 0: if ret < 0:
raise Exception('Failed to encode: {}'.format(ret)) raise Exception('Failed to encode: {}'.format(ret))
return array.array('b', data[:ret]).tobytes() # TODO: py3
return array.array('b', data[:ret]).tostring()
class OpusDecoder(BaseOpus): class OpusDecoder(BaseOpus):
pass pass
class BufferedOpusEncoder(OpusEncoder):
def __init__(self, data, *args, **kwargs):
self.data = StringIO(data)
self.frames = Queue(kwargs.pop('queue_size', 4096))
super(BufferedOpusEncoder, self).__init__(*args, **kwargs)
gevent.spawn(self._encoder_loop)
def _encoder_loop(self):
while self.data:
raw = self.data.read(self.frame_size)
if len(raw) < self.frame_size:
break
self.frames.put(self.encode(raw, self.samples_per_frame))
gevent.idle()
self.data = None
def have_frame(self):
return self.data or not self.frames.empty()
def next_frame(self):
return self.frames.get()
class GIPCBufferedOpusEncoder(OpusEncoder):
FIN = 1
def __init__(self, data, *args, **kwargs):
import gipc
self.data = StringIO(data)
self.parent_pipe, self.child_pipe = gipc.pipe(duplex=True)
self.frames = Queue(kwargs.pop('queue_size', 4096))
super(GIPCBufferedOpusEncoder, self).__init__(*args, **kwargs)
gipc.start_process(target=self._encoder_loop, args=(self.child_pipe, (args, kwargs)))
gevent.spawn(self._writer)
gevent.spawn(self._reader)
def _reader(self):
while True:
data = self.parent_pipe.get()
if data == self.FIN:
return
self.frames.put(data)
self.parent_pipe = None
def _writer(self):
while self.data:
raw = self.data.read(self.frame_size)
if len(raw) < self.frame_size:
break
self.parent_pipe.put(raw)
gevent.idle()
self.parent_pipe.put(self.FIN)
def have_frame(self):
return self.parent_pipe
def next_frame(self):
return self.frames.get()
@classmethod
def _encoder_loop(cls, pipe, (args, kwargs)):
encoder = OpusEncoder(*args, **kwargs)
while True:
data = pipe.get()
if data == cls.FIN:
pipe.put(cls.FIN)
return
pipe.put(encoder.encode(data, encoder.samples_per_frame))

49
disco/voice/player.py

@ -1,18 +1,54 @@
import time
import gevent import gevent
import struct import struct
import time import subprocess
from six.moves import queue from six.moves import queue
from disco.voice.client import VoiceState from disco.voice.client import VoiceState
from disco.voice.opus import BufferedOpusEncoder, GIPCBufferedOpusEncoder
class BaseFFmpegPlayable(object):
def __init__(self, source='-', command='avconv', sampling_rate=48000, channels=2, **kwargs):
args = [command, '-i', source, '-f', 's16le', '-ar', str(sampling_rate), '-ac', str(channels), '-loglevel', 'warning', 'pipe:1']
self.proc = subprocess.Popen(args, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
data, _ = self.proc.communicate()
super(BaseFFmpegPlayable, self).__init__(data, sampling_rate, channels, **kwargs)
class FFmpegPlayable(BaseFFmpegPlayable, BufferedOpusEncoder):
pass
class GIPCFFmpegPlayable(BaseFFmpegPlayable, GIPCBufferedOpusEncoder):
pass
class OpusItem(object): def create_youtube_dl_playable(url, cls=FFmpegPlayable, *args, **kwargs):
def __init__(self, frame_length=20, channels=2): import youtube_dl
ydl = youtube_dl.YoutubeDL({'format': 'webm[abr>0]/bestaudio/best'})
info = ydl.extract_info(url, download=False)
if 'entries' in info:
info = info['entries'][0]
return cls(info['url'], *args, **kwargs), info
class OpusPlayable(object):
"""
Represents a Playable item which is a cached set of Opus-encoded bytes.
"""
def __init__(self, sampling_rate=48000, frame_length=20, channels=2):
self.frames = [] self.frames = []
self.idx = 0 self.idx = 0
self.frame_length = 20
self.sampling_rate = sampling_rate
self.frame_length = frame_length self.frame_length = frame_length
self.channels = channels self.channels = channels
self.sample_size = int(self.sampling_rate / 1000 * self.frame_length)
@classmethod @classmethod
def from_raw_file(cls, path): def from_raw_file(cls, path):
@ -58,6 +94,7 @@ class Player(object):
def play(self, item): def play(self, item):
start = time.time() start = time.time()
loops = 0 loops = 0
timestamp = 0
while True: while True:
loops += 1 loops += 1
@ -76,13 +113,15 @@ class Player(object):
if not item.have_frame(): if not item.have_frame():
return return
self.client.send_frame(item.next_frame()) self.client.send_frame(item.next_frame(), loops, timestamp)
timestamp += item.samples_per_frame
next_time = start + 0.02 * loops next_time = start + 0.02 * loops
delay = max(0, 0.02 + (next_time - time.time())) delay = max(0, 0.02 + (next_time - time.time()))
gevent.sleep(delay) gevent.sleep(delay)
def run(self): def run(self):
self.client.set_speaking(True) self.client.set_speaking(True)
while self.playing: while self.playing:
self.play(self.queue.get()) self.play(self.queue.get())
@ -90,4 +129,6 @@ class Player(object):
self.playing = False self.playing = False
self.complete.set() self.complete.set()
return return
self.client.set_speaking(False) self.client.set_speaking(False)
self.disconnect()

9
examples/music.py

@ -1,15 +1,16 @@
from disco.bot import Plugin from disco.bot import Plugin
from disco.bot.command import CommandError from disco.bot.command import CommandError
from disco.voice.client import Player, OpusItem, VoiceException from disco.voice.player import Player, create_youtube_dl_playable
from disco.voice.client import VoiceException
def download(url): def download(url):
return OpusItem.from_raw_file('test.dca') return create_youtube_dl_playable(url)[0]
class MusicPlugin(Plugin): class MusicPlugin(Plugin):
def load(self): def load(self, ctx):
super(MusicPlugin, self).load() super(MusicPlugin, self).load(ctx)
self.guilds = {} self.guilds = {}
@Plugin.command('join') @Plugin.command('join')

2
requirements.txt

@ -1,5 +1,5 @@
gevent==1.1.2 gevent==1.1.2
holster==1.0.7 holster==1.0.11
inflection==0.3.1 inflection==0.3.1
requests==2.11.1 requests==2.11.1
six==1.10.0 six==1.10.0

Loading…
Cancel
Save