diff --git a/disco/api/client.py b/disco/api/client.py index 2f2533e..56423a5 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -2,13 +2,14 @@ import six import json import warnings -from disco.api.http import Routes, HTTPClient +from six.moves.urllib.parse import quote + +from disco.api.http import Routes, HTTPClient, to_bytes from disco.util.logging import LoggingClass from disco.util.sanitize import S - from disco.types.user import User from disco.types.message import Message -from disco.types.guild import Guild, GuildMember, GuildBan, Role, GuildEmoji +from disco.types.guild import Guild, GuildMember, GuildBan, Role, GuildEmoji, AuditLogEntry from disco.types.channel import Channel from disco.types.invite import Invite from disco.types.webhook import Webhook @@ -24,6 +25,10 @@ def optional(**kwargs): return {k: v for k, v in six.iteritems(kwargs) if v is not None} +def _reason_header(value): + return optional(**{'X-Audit-Log-Reason': quote(to_bytes(value))}) + + class APIClient(LoggingClass): """ An abstraction over a :class:`disco.api.http.HTTPClient`, which composes @@ -65,12 +70,19 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_GET, dict(channel=channel)) return Channel.create(self.client, r.json()) - def channels_modify(self, channel, **kwargs): - r = self.http(Routes.CHANNELS_MODIFY, dict(channel=channel), json=kwargs) + def channels_modify(self, channel, reason=None, **kwargs): + r = self.http( + Routes.CHANNELS_MODIFY, + dict(channel=channel), + json=kwargs, + headers=_reason_header(reason)) return Channel.create(self.client, r.json()) - def channels_delete(self, channel): - r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel)) + def channels_delete(self, channel, reason=None): + r = self.http( + Routes.CHANNELS_DELETE, + dict(channel=channel), + headers=_reason_header(reason)) return Channel.create(self.client, r.json()) def channels_typing(self, channel): @@ -175,27 +187,27 @@ class APIClient(LoggingClass): self.http(route, obj) - def channels_permissions_modify(self, channel, permission, allow, deny, typ): + def channels_permissions_modify(self, channel, permission, allow, deny, typ, reason=None): self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={ 'allow': allow, 'deny': deny, 'type': typ, - }) + }, headers=_reason_header(reason)) - def channels_permissions_delete(self, channel, permission): - self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission)) + def channels_permissions_delete(self, channel, permission, reason=None): + self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission), headers=_reason_header(reason)) def channels_invites_list(self, channel): r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel)) return Invite.create_map(self.client, r.json()) - def channels_invites_create(self, channel, max_age=86400, max_uses=0, temporary=False, unique=False): + def channels_invites_create(self, channel, max_age=86400, max_uses=0, temporary=False, unique=False, reason=None): r = self.http(Routes.CHANNELS_INVITES_CREATE, dict(channel=channel), json={ 'max_age': max_age, 'max_uses': max_uses, 'temporary': temporary, 'unique': unique - }) + }, headers=_reason_header(reason)) return Invite.create(self.client, r.json()) def channels_pins_list(self, channel): @@ -223,8 +235,8 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_GET, dict(guild=guild)) return Guild.create(self.client, r.json()) - def guilds_modify(self, guild, **kwargs): - r = self.http(Routes.GUILDS_MODIFY, dict(guild=guild), json=kwargs) + def guilds_modify(self, guild, reason=None, **kwargs): + r = self.http(Routes.GUILDS_MODIFY, dict(guild=guild), json=kwargs, headers=_reason_header(reason)) return Guild.create(self.client, r.json()) def guilds_delete(self, guild): @@ -235,7 +247,7 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild) - def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[]): + def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[], reason=None): payload = { 'name': name, 'channel_type': channel_type, @@ -254,14 +266,18 @@ class APIClient(LoggingClass): # TODO: better error here? raise Exception('Invalid channel type: {}'.format(channel_type)) - r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=payload) + r = self.http( + Routes.GUILDS_CHANNELS_CREATE, + dict(guild=guild), + json=payload, + headers=_reason_header(reason)) return Channel.create(self.client, r.json(), guild_id=guild) - def guilds_channels_modify(self, guild, channel, position): + def guilds_channels_modify(self, guild, channel, position, reason=None): self.http(Routes.GUILDS_CHANNELS_MODIFY, dict(guild=guild), json={ 'id': channel, 'position': position, - }) + }, headers=_reason_header(reason)) def guilds_members_list(self, guild, limit=1000, after=None): r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild), params=optional( @@ -274,51 +290,68 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member)) return GuildMember.create(self.client, r.json(), guild_id=guild) - def guilds_members_modify(self, guild, member, **kwargs): - self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=optional(**kwargs)) - - def guilds_members_roles_add(self, guild, member, role): - self.http(Routes.GUILDS_MEMBERS_ROLES_ADD, dict(guild=guild, member=member, role=role)) - - def guilds_members_roles_remove(self, guild, member, role): - self.http(Routes.GUILDS_MEMBERS_ROLES_REMOVE, dict(guild=guild, member=member, role=role)) + def guilds_members_modify(self, guild, member, reason=None, **kwargs): + self.http( + Routes.GUILDS_MEMBERS_MODIFY, + dict(guild=guild, member=member), + json=optional(**kwargs), + headers=_reason_header(reason)) + + def guilds_members_roles_add(self, guild, member, role, reason=None): + self.http( + Routes.GUILDS_MEMBERS_ROLES_ADD, + dict(guild=guild, member=member, role=role), + headers=_reason_header(reason)) + + def guilds_members_roles_remove(self, guild, member, role, reason=None): + self.http( + Routes.GUILDS_MEMBERS_ROLES_REMOVE, + dict(guild=guild, member=member, role=role), + headers=_reason_header(reason)) def guilds_members_me_nick(self, guild, nick): self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick}) - def guilds_members_kick(self, guild, member): - self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member)) + def guilds_members_kick(self, guild, member, reason=None): + self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member), headers=_reason_header(reason)) def guilds_bans_list(self, guild): r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild)) return GuildBan.create_hash(self.client, 'user.id', r.json()) - def guilds_bans_create(self, guild, user, delete_message_days): + def guilds_bans_create(self, guild, user, delete_message_days=0, reason=None): self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ 'delete-message-days': delete_message_days, - }) + }, headers=_reason_header(reason)) - def guilds_bans_delete(self, guild, user): - self.http(Routes.GUILDS_BANS_DELETE, dict(guild=guild, user=user)) + def guilds_bans_delete(self, guild, user, reason=None): + self.http( + Routes.GUILDS_BANS_DELETE, + dict(guild=guild, user=user), + headers=_reason_header(reason)) def guilds_roles_list(self, guild): r = self.http(Routes.GUILDS_ROLES_LIST, dict(guild=guild)) return Role.create_map(self.client, r.json(), guild_id=guild) - def guilds_roles_create(self, guild): - r = self.http(Routes.GUILDS_ROLES_CREATE, dict(guild=guild)) + def guilds_roles_create(self, guild, reason=None): + r = self.http(Routes.GUILDS_ROLES_CREATE, dict(guild=guild), headers=_reason_header(reason)) return Role.create(self.client, r.json(), guild_id=guild) - def guilds_roles_modify_batch(self, guild, roles): - r = self.http(Routes.GUILDS_ROLES_MODIFY_BATCH, dict(guild=guild), json=roles) + def guilds_roles_modify_batch(self, guild, roles, reason=None): + r = self.http(Routes.GUILDS_ROLES_MODIFY_BATCH, dict(guild=guild), json=roles, headers=_reason_header(reason)) return Role.create_map(self.client, r.json(), guild_id=guild) - def guilds_roles_modify(self, guild, role, **kwargs): - r = self.http(Routes.GUILDS_ROLES_MODIFY, dict(guild=guild, role=role), json=kwargs) + def guilds_roles_modify(self, guild, role, reason=None, **kwargs): + r = self.http( + Routes.GUILDS_ROLES_MODIFY, + dict(guild=guild, role=role), + json=kwargs, + headers=_reason_header(reason)) return Role.create(self.client, r.json(), guild_id=guild) - def guilds_roles_delete(self, guild, role): - self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role)) + def guilds_roles_delete(self, guild, role, reason=None): + self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role), headers=_reason_header(reason)) def guilds_invites_list(self, guild): r = self.http(Routes.GUILDS_INVITES_LIST, dict(guild=guild)) @@ -332,16 +365,41 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_EMOJIS_LIST, dict(guild=guild)) return GuildEmoji.create_map(self.client, r.json()) - def guilds_emojis_create(self, guild, **kwargs): - r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs) + def guilds_emojis_create(self, guild, reason=None, **kwargs): + r = self.http( + Routes.GUILDS_EMOJIS_CREATE, + dict(guild=guild), + json=kwargs, + headers=_reason_header(reason)) return GuildEmoji.create(self.client, r.json(), guild_id=guild) - def guilds_emojis_modify(self, guild, emoji, **kwargs): - r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs) + def guilds_emojis_modify(self, guild, emoji, reason=None, **kwargs): + r = self.http( + Routes.GUILDS_EMOJIS_MODIFY, + dict(guild=guild, emoji=emoji), + json=kwargs, + headers=_reason_header(reason)) return GuildEmoji.create(self.client, r.json(), guild_id=guild) - def guilds_emojis_delete(self, guild, emoji): - self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji)) + def guilds_emojis_delete(self, guild, emoji, reason=None): + self.http( + Routes.GUILDS_EMOJIS_DELETE, + dict(guild=guild, emoji=emoji), + headers=_reason_header(reason)) + + def guilds_auditlogs_list(self, guild, before=None, user_id=None, action_type=None, limit=50): + r = self.http(Routes.GUILDS_AUDITLOGS_LIST, dict(guild=guild), params=optional( + before=before, + user_id=user_id, + action_type=int(action_type) if action_type else None, + limit=limit, + )) + + data = r.json() + + users = User.create_hash(self.client, 'id', data['users']) + webhooks = Webhook.create_hash(self.client, 'id', data['webhooks']) + return AuditLogEntry.create_map(self.client, r.json()['audit_log_entries'], users, webhooks, guild_id=guild) def users_me_get(self): return User.create(self.client, self.http(Routes.USERS_ME_GET).json()) @@ -363,23 +421,23 @@ class APIClient(LoggingClass): r = self.http(Routes.INVITES_GET, dict(invite=invite)) return Invite.create(self.client, r.json()) - def invites_delete(self, invite): - r = self.http(Routes.INVITES_DELETE, dict(invite=invite)) + def invites_delete(self, invite, reason=None): + r = self.http(Routes.INVITES_DELETE, dict(invite=invite), headers=_reason_header(reason)) return Invite.create(self.client, r.json()) def webhooks_get(self, webhook): r = self.http(Routes.WEBHOOKS_GET, dict(webhook=webhook)) return Webhook.create(self.client, r.json()) - def webhooks_modify(self, webhook, name=None, avatar=None): + def webhooks_modify(self, webhook, name=None, avatar=None, reason=None): r = self.http(Routes.WEBHOOKS_MODIFY, dict(webhook=webhook), json=optional( name=name, avatar=avatar, - )) + ), headers=_reason_header(reason)) return Webhook.create(self.client, r.json()) - def webhooks_delete(self, webhook): - self.http(Routes.WEBHOOKS_DELETE, dict(webhook=webhook)) + def webhooks_delete(self, webhook, reason=None): + self.http(Routes.WEBHOOKS_DELETE, dict(webhook=webhook), headers=_reason_header(reason)) def webhooks_token_get(self, webhook, token): r = self.http(Routes.WEBHOOKS_TOKEN_GET, dict(webhook=webhook, token=token)) diff --git a/disco/api/http.py b/disco/api/http.py index 732f88c..cc592ca 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -103,6 +103,7 @@ class Routes(object): GUILDS_EMOJIS_CREATE = (HTTPMethod.POST, GUILDS + '/emojis') GUILDS_EMOJIS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/emojis/{emoji}') GUILDS_EMOJIS_DELETE = (HTTPMethod.DELETE, GUILDS + '/emojis/{emoji}') + GUILDS_AUDITLOGS_LIST = (HTTPMethod.GET, GUILDS + '/audit-logs') # Users USERS = '/users' @@ -251,6 +252,8 @@ class HTTPClient(LoggingClass): # Possibly wait if we're rate limited self.limiter.check(bucket) + self.log.debug('KW: %s', kwargs) + # Make the actual request url = self.BASE_URL + route[1].format(**args) self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params')) @@ -263,6 +266,7 @@ class HTTPClient(LoggingClass): if r.status_code < 400: return r elif r.status_code != 429 and 400 <= r.status_code < 500: + self.log.warning('Request failed with code %s: %s', r.status_code, r.content) raise APIException(r) else: if r.status_code == 429: diff --git a/disco/cli.py b/disco/cli.py index 7e54333..854a344 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -50,6 +50,9 @@ def disco_main(run=False): else: config = ClientConfig() + config.manhole_enable = args.manhole + config.manhole_bind = args.manhole_bind.split(':', 1) + for k, v in six.iteritems(vars(args)): if hasattr(config, k) and v is not None: setattr(config, k, v) diff --git a/disco/types/base.py b/disco/types/base.py index f219b9b..54cbcfd 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -349,8 +349,8 @@ class Model(six.with_metaclass(ModelMeta, Chainable)): return inst @classmethod - def create_map(cls, client, data, **kwargs): - return list(map(functools.partial(cls.create, client, **kwargs), data)) + def create_map(cls, client, data, *args, **kwargs): + return list(map(functools.partial(cls.create, client, *args, **kwargs), data)) @classmethod def create_hash(cls, client, key, data, **kwargs): diff --git a/disco/types/channel.py b/disco/types/channel.py index 853ca40..a8d7936 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -369,9 +369,9 @@ class Channel(SlottedModel, Permissible): for msg in messages: self.delete_message(msg) - def delete(self): + def delete(self, **kwargs): assert (self.is_dm or self.guild.can(self.client.state.me, Permissions.MANAGE_GUILD)), 'Invalid Permissions' - self.client.api.channels_delete(self.id) + self.client.api.channels_delete(self.id, **kwargs) def close(self): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index 7fb5308..fe2b4a3 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -4,10 +4,11 @@ from holster.enum import Enum from disco.gateway.packets import OPCode from disco.api.http import APIException +from disco.util.paginator import Paginator from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.base import ( - SlottedModel, Field, ListField, AutoDictField, snowflake, text, enum, datetime + SlottedModel, Field, ListField, AutoDictField, DictField, snowflake, text, enum, datetime ) from disco.types.user import User from disco.types.voice import VoiceState @@ -436,8 +437,8 @@ class Guild(SlottedModel, Permissible): def delete_ban(self, user): self.client.api.guilds_bans_delete(self.id, to_snowflake(user)) - def create_ban(self, user, delete_message_days=0): - self.client.api.guilds_bans_create(self.id, to_snowflake(user), delete_message_days) + def create_ban(self, user, *args, **kwargs): + self.client.api.guilds_bans_create(self.id, to_snowflake(user), *args, **kwargs) def create_channel(self, *args, **kwargs): return self.client.api.guilds_channels_create(self.id, *args, **kwargs) @@ -470,3 +471,148 @@ class Guild(SlottedModel, Permissible): @property def splash_url(self): return self.get_splash_url() + + @property + def audit_log(self): + return Paginator( + self.client.api.guilds_auditlogs_list, + 'before', + self.id, + ) + + def get_audit_log_entries(self, *args, **kwargs): + return self.client.api.guilds_auditlogs_list(self.id, *args, **kwargs) + + +AuditLogActionTypes = Enum( + GUILD_UPDATE=1, + CHANNEL_CREATE=10, + CHANNEL_UPDATE=11, + CHANNEL_DELETE=12, + CHANNEL_OVERWRITE_CREATE=13, + CHANNEL_OVERWRITE_UPDATE=14, + CHANNEL_OVERWRITE_DELETE=15, + MEMBER_KICK=20, + MEMBER_PRUNE=21, + MEMBER_BAN_ADD=22, + MEMBER_BAN_REMOVE=23, + MEMBER_UPDATE=24, + MEMBER_ROLE_UPDATE=25, + ROLE_CREATE=30, + ROLE_UPDATE=31, + ROLE_DELETE=32, + INVITE_CREATE=40, + INVITE_UPDATE=41, + INVITE_DELETE=42, + WEBHOOK_CREATE=50, + WEBHOOK_UPDATE=51, + WEBHOOK_DELETE=52, + EMOJI_CREATE=60, + EMOJI_UPDATE=61, + EMOJI_DELETE=62, + MESSAGE_DELETE=72, +) + + +GUILD_ACTIONS = ( + AuditLogActionTypes.GUILD_UPDATE, +) + +CHANNEL_ACTIONS = ( + AuditLogActionTypes.CHANNEL_CREATE, + AuditLogActionTypes.CHANNEL_UPDATE, + AuditLogActionTypes.CHANNEL_DELETE, + AuditLogActionTypes.CHANNEL_OVERWRITE_CREATE, + AuditLogActionTypes.CHANNEL_OVERWRITE_UPDATE, + AuditLogActionTypes.CHANNEL_OVERWRITE_DELETE, +) + +MEMBER_ACTIONS = ( + AuditLogActionTypes.MEMBER_KICK, + AuditLogActionTypes.MEMBER_PRUNE, + AuditLogActionTypes.MEMBER_BAN_ADD, + AuditLogActionTypes.MEMBER_BAN_REMOVE, + AuditLogActionTypes.MEMBER_UPDATE, + AuditLogActionTypes.MEMBER_ROLE_UPDATE, +) + +ROLE_ACTIONS = ( + AuditLogActionTypes.ROLE_CREATE, + AuditLogActionTypes.ROLE_UPDATE, + AuditLogActionTypes.ROLE_DELETE, +) + +INVITE_ACTIONS = ( + AuditLogActionTypes.INVITE_CREATE, + AuditLogActionTypes.INVITE_UPDATE, + AuditLogActionTypes.INVITE_DELETE, +) + +WEBHOOK_ACTIONS = ( + AuditLogActionTypes.WEBHOOK_CREATE, + AuditLogActionTypes.WEBHOOK_UPDATE, + AuditLogActionTypes.WEBHOOK_DELETE, +) + +EMOJI_ACTIONS = ( + AuditLogActionTypes.EMOJI_CREATE, + AuditLogActionTypes.EMOJI_UPDATE, + AuditLogActionTypes.EMOJI_DELETE, +) + +MESSAGE_ACTIONS = ( + AuditLogActionTypes.MESSAGE_DELETE, +) + + +class AuditLogObjectChange(SlottedModel): + key = Field(text) + new_value = Field(text) + old_value = Field(text) + + +class AuditLogEntry(SlottedModel): + id = Field(snowflake) + guild_id = Field(snowflake) + user_id = Field(snowflake) + target_id = Field(snowflake) + action_type = Field(enum(AuditLogActionTypes)) + changes = ListField(AuditLogObjectChange) + options = DictField(text, text) + reason = Field(text) + + _cached_target = Field(None) + + @classmethod + def create(cls, client, users, webhooks, data, **kwargs): + self = super(SlottedModel, cls).create(client, data, **kwargs) + + if self.action_type in MEMBER_ACTIONS: + self._cached_target = users[self.target_id] + elif self.action_type in WEBHOOK_ACTIONS: + self._cached_target = webhooks[self.target_id] + + return self + + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + + @cached_property + def user(self): + return self.client.state.users.get(self.user_id) + + @cached_property + def target(self): + if self.action_type in GUILD_ACTIONS: + return self.guild + elif self.action_type in CHANNEL_ACTIONS: + return self.guild.channels.get(self.target_id) + elif self.action_type in MEMBER_ACTIONS: + return self._cached_target or self.state.users.get(self.target_id) + elif self.action_type in ROLE_ACTIONS: + return self.guild.roles.get(self.target_id) + elif self.action_type in WEBHOOK_ACTIONS: + return self._cached_target + elif self.action_type in EMOJI_ACTIONS: + return self.guild.emojis.get(self.target_id) diff --git a/disco/types/invite.py b/disco/types/invite.py index 7f22a57..1b458ef 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -40,13 +40,8 @@ class Invite(SlottedModel): created_at = Field(datetime) @classmethod - def create_for_channel(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False): - return channel.client.api.channels_invites_create( - channel.id, - max_age=max_age, - max_uses=max_uses, - temporary=temporary, - unique=unique) + def create_for_channel(cls, channel, *args, **kwargs): + return channel.client.api.channels_invites_create(channel.id, *args, **kwargs) - def delete(self): - self.client.api.invites_delete(self.code) + def delete(self, *args, **kwargs): + self.client.api.invites_delete(self.code, *args, **kwargs) diff --git a/disco/types/message.py b/disco/types/message.py index 07b44af..3077b32 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -315,6 +315,7 @@ class Message(SlottedModel): return Paginator( self.client.api.channels_messages_reactions_get, + 'after', self.channel_id, self.id, emoji, diff --git a/disco/util/functional.py b/disco/util/functional.py index 29921a3..d52d4e9 100644 --- a/disco/util/functional.py +++ b/disco/util/functional.py @@ -56,6 +56,9 @@ class CachedSlotProperty(object): self.function = function self.__doc__ = getattr(function, '__doc__') + def set(self, value): + setattr(self.stored_name, value) + def __get__(self, instance, owner): if instance is None: return self diff --git a/disco/util/paginator.py b/disco/util/paginator.py index 715f8d9..7e4b732 100644 --- a/disco/util/paginator.py +++ b/disco/util/paginator.py @@ -5,25 +5,26 @@ class Paginator(object): """ Implements a class which provides paginated iteration over an endpoint. """ - def __init__(self, func, *args, **kwargs): + def __init__(self, func, sort_key, *args, **kwargs): self.func = func + self.sort_key = sort_key self.args = args self.kwargs = kwargs self._key = kwargs.pop('key', operator.attrgetter('id')) self._bulk = kwargs.pop('bulk', False) - self._after = kwargs.pop('after', None) + self._sort_key_value = kwargs.pop(self.sort_key, 0) self._buffer = [] def fill(self): - self.kwargs['after'] = self._after + self.kwargs[self.sort_key] = self._sort_key_value result = self.func(*self.args, **self.kwargs) if not len(result): return self._buffer.extend(result) - self._after = self._key(result[-1]) + self._sort_key_value = self._key(result[-1]) def next(self): return self.__next__() diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index b47dc6e..6dc57c3 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -3,6 +3,17 @@ from disco.util.sanitize import S class BasicPlugin(Plugin): + @Plugin.command('auditme') + def on_auditme(self, event): + invite = event.channel.create_invite(reason='TEST AUDIT') + invite.delete(reason='TEST AUDIT 2') + # channel = event.guild.create_channel('audit-log-test', 'text', reason='TEST CREATE') + # channel.delete(reason='TEST AUDIT 2') + + @Plugin.command('ban', ' ') + def on_ban(self, event, user, reason): + event.guild.create_ban(user, reason=reason + u'\U0001F4BF') + @Plugin.command('ping') def on_ping_command(self, event): # Generally all the functionality you need to interact with is contained @@ -14,6 +25,7 @@ class BasicPlugin(Plugin): # All of Discord's events can be listened too and handled easily self.log.info(u'{}: {}'.format(event.author, event.content)) + @Plugin.command('test') @Plugin.command('echo', '') def on_echo_command(self, event, content): # Commands can take a set of arguments that are validated by Disco itself diff --git a/tests/test_reason.py b/tests/test_reason.py new file mode 100644 index 0000000..8d9b39c --- /dev/null +++ b/tests/test_reason.py @@ -0,0 +1,11 @@ +from unittest import TestCase +from utils import TestAPIClient + + +class TestReason(TestCase): + def test_set_unicode_reason(self): + api = TestAPIClient() + api.guilds_channels_modify(1, 2, 3, reason=u'yo \U0001F4BF test') + + _, kwargs = api.http.calls[0] + self.assertEquals(kwargs['headers']['X-Audit-Log-Reason'], 'yo%20%F0%9F%92%BF%20test') diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..74f9766 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,15 @@ +from disco.api.client import APIClient + + +class CallContainer(object): + def __init__(self): + self.calls = [] + + def __call__(self, *args, **kwargs): + self.calls.append((args, kwargs)) + + +class TestAPIClient(APIClient): + def __init__(self): + self.client = None + self.http = CallContainer()