diff --git a/disco/api/http.py b/disco/api/http.py index a11f6ac..5a05f7a 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -137,7 +137,7 @@ class HTTPClient(LoggingClass): # If we got a success status code, just return the data if r.status_code < 400: return r - elif 400 < r.status_code < 500: + elif r.status_code != 429 and 400 < r.status_code < 500: raise APIException('Request failed', r.status_code, r.content) else: if r.status_code == 429: diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py index 80ed940..8d91532 100644 --- a/disco/api/ratelimit.py +++ b/disco/api/ratelimit.py @@ -56,7 +56,7 @@ class RateLimiter(object): return self.states[route].wait(timeout) if self.states[route].next_will_ratelimit(): - gevent.spawn(self.states[route].cooldown).wait(timeout) + gevent.spawn(self.states[route].cooldown).get(True, timeout) return True diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 7fa77d9..9963e5c 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -54,7 +54,7 @@ class ArgumentSet(object): def __init__(self, args=None, custom_types=None): self.args = args or [] self.types = copy.copy(TYPE_MAP) - self.types.update(custom_types) + self.types.update(custom_types or {}) def convert(self, types, value): for typ_name in types: diff --git a/disco/types/channel.py b/disco/types/channel.py index 2d18d15..587a9fd 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -6,7 +6,7 @@ from disco.util.cache import cached_property from disco.util.types import ListToDictType from disco.types.base import BaseType from disco.types.user import User -from disco.types.permissions import * +from disco.types.permissions import PermissionType, Permissions, Permissible from disco.voice.client import VoiceClient @@ -31,7 +31,7 @@ class PermissionOverwrite(BaseType): deny = PermissionType() -class Channel(BaseType): +class Channel(BaseType, Permissible): id = skema.SnowflakeType() guild_id = skema.SnowflakeType(required=False) @@ -46,6 +46,22 @@ class Channel(BaseType): overwrites = ListToDictType('id', skema.ModelType(PermissionOverwrite), stored_name='permission_overwrites') + def get_permissions(self, user): + if not self.guild_id: + return Permissions.ADMINISTRATOR + + member = self.guild.members.get(user.id) + base = self.guild.get_permissions(user) + + for ow in self.overwrites.values(): + if ow.id != user.id and ow.id not in member.roles: + continue + + base -= ow.deny + base += ow.allow + + return base + @property def is_guild(self): return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE) diff --git a/disco/types/guild.py b/disco/types/guild.py index d26749f..7802159 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -3,11 +3,11 @@ import copy from disco.api.http import APIException from disco.util import to_snowflake -from disco.types.base import BaseType from disco.util.types import PreHookType, ListToDictType +from disco.types.base import BaseType from disco.types.user import User from disco.types.voice import VoiceState -from disco.types.permissions import PermissionType +from disco.types.permissions import PermissionType, PermissionValue, Permissions, Permissible from disco.types.channel import Channel @@ -51,7 +51,7 @@ class GuildMember(BaseType): return self.user.id -class Guild(BaseType): +class Guild(BaseType, Permissible): id = skema.SnowflakeType() owner_id = skema.SnowflakeType() @@ -76,6 +76,18 @@ class Guild(BaseType): emojis = ListToDictType('id', skema.ModelType(Emoji)) voice_states = ListToDictType('session_id', skema.ModelType(VoiceState)) + def get_permissions(self, user): + if self.owner_id == user.id: + return PermissionValue(Permissions.ADMINISTRATOR) + + member = self.get_member(user) + value = PermissionValue(self.roles.get(self.id).permissions) + + for role in map(self.roles.get, member.roles): + value += role.permissions + + return value + def get_voice_state(self, user): user = to_snowflake(user) diff --git a/disco/types/permissions.py b/disco/types/permissions.py index a46767c..76c0ce9 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -1,9 +1,8 @@ from skema import NumberType -from holster.enum import Enum +from holster.enum import Enum, EnumAttr Permissions = Enum( - NONE=0, CREATE_INSTANT_INVITE=1 << 0, KICK_MEMBERS=1 << 1, BAN_MEMBERS=1 << 2, @@ -34,9 +33,48 @@ Permissions = Enum( class PermissionValue(object): - def __init__(self, value): + def __init__(self, value=0): + if isinstance(value, EnumAttr) or isinstance(value, PermissionValue): + value = value.value + self.value = value + def can(self, *perms): + for perm in perms: + if isinstance(perm, EnumAttr): + perm = perm.value + if not (self.value & perm) == perm: + return False + return True + + def add(self, other): + if isinstance(other, PermissionValue): + self.value |= other.value + elif isinstance(other, int): + self.value |= other + elif isinstance(other, EnumAttr): + setattr(self, other.name, True) + else: + raise TypeError('Cannot PermissionValue.add from type {}'.format(type(other))) + return self + + def sub(self, other): + if isinstance(other, PermissionValue): + self.value &= ~other.value + elif isinstance(other, int): + self.value &= other + elif isinstance(other, EnumAttr): + setattr(self, other.name, False) + else: + raise TypeError('Cannot PermissionValue.sub from type {}'.format(type(other))) + return self + + def __iadd__(self, other): + return self.add(other) + + def __isub__(self, other): + return self.sub(other) + def __getattribute__(self, name): if name in Permissions.attrs: return (self.value & Permissions[name].value) == Permissions[name].value @@ -69,3 +107,9 @@ class PermissionValue(object): class PermissionType(NumberType): def __init__(self, *args, **kwargs): super(PermissionType, self).__init__(number_class=PermissionValue, number_type='PermissionValue', *args, **kwargs) + + +class Permissible(object): + def can(self, user, *args): + perms = self.get_permissions(user) + return perms.administrator or perms.can(*args) diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 36134e6..11a1b66 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -1,10 +1,12 @@ import gevent import sys +import json from disco import VERSION from disco.cli import disco_main from disco.bot import Bot from disco.bot.plugin import Plugin +from disco.types.permissions import Permissions class BasicPlugin(Plugin): @@ -81,6 +83,17 @@ class BasicPlugin(Plugin): gevent.sleep(1) vc.disconnect() + @Plugin.command('lol') + def on_lol(self, event): + event.msg.reply("{}".format(event.channel.can(event.msg.author, Permissions.MANAGE_EMOJIS))) + + @Plugin.command('perms') + def on_perms(self, event): + perms = event.channel.get_permissions(event.msg.author) + event.msg.reply('```json\n{}\n```'.format( + json.dumps(perms.to_dict(), sort_keys=True, indent=2, separators=(',', ': ')) + )) + if __name__ == '__main__': bot = Bot(disco_main()) bot.add_plugin(BasicPlugin)