diff --git a/discord/client.py b/discord/client.py index 7059f8b84..45c9bfa72 100644 --- a/discord/client.py +++ b/discord/client.py @@ -2021,11 +2021,11 @@ class Client: yield from response.release() @asyncio.coroutine - def _replace_roles(self, member, *roles): + def _replace_roles(self, member, roles): url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) payload = { - 'roles': list(roles) + 'roles': roles } r = yield from aiohttp.patch(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) @@ -2059,8 +2059,8 @@ class Client: Adding roles failed. """ - new_roles = {role.id for role in itertools.chain(member.roles, roles)} - yield from self._replace_roles(member, *new_roles) + new_roles = utils._unique(itertools.chain(member.roles, roles)) + yield from self._replace_roles(member, new_roles) @asyncio.coroutine def remove_roles(self, member, *roles): @@ -2086,9 +2086,16 @@ class Client: HTTPException Removing roles failed. """ - new_roles = {role.id for role in member.roles} - new_roles = new_roles.difference(role.id for role in roles) - yield from self._replace_roles(member, *new_roles) + new_roles = [x.id for x in member.roles] + remove = [] + for index, role in enumerate(roles): + if role.id in new_roles: + remove.append(index) + + for index in reversed(remove): + del new_roles[index] + + yield from self._replace_roles(member, new_roles) @asyncio.coroutine def replace_roles(self, member, *roles): @@ -2120,8 +2127,8 @@ class Client: Removing roles failed. """ - new_roles = {role.id for role in roles} - yield from self._replace_roles(member, *new_roles) + new_roles = utils._unique(roles) + yield from self._replace_roles(member, new_roles) @asyncio.coroutine def create_role(self, server, **fields): diff --git a/discord/utils.py b/discord/utils.py index 74c070c88..e74c2e99d 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -137,6 +137,11 @@ def get(iterable, **attrs): return find(predicate, iterable) +def _unique(iterable): + seen = set() + adder = seen.add + return [x for x in iterable if not (x in seen or adder(x))] + def _null_event(*args, **kwargs): pass