Browse Source

Code optimisations and refactoring via Sourcery

pull/6483/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
63ec23bac2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      discord/audit_logs.py
  2. 14
      discord/client.py
  3. 2
      discord/errors.py
  4. 7
      discord/ext/commands/bot.py
  5. 6
      discord/ext/commands/cog.py
  6. 2
      discord/ext/commands/context.py
  7. 5
      discord/ext/commands/core.py
  8. 2
      discord/ext/tasks/__init__.py
  9. 13
      discord/gateway.py
  10. 28
      discord/iterators.py
  11. 10
      discord/member.py
  12. 13
      discord/opus.py
  13. 2
      discord/shard.py
  14. 9
      discord/state.py
  15. 8
      discord/template.py
  16. 6
      discord/user.py
  17. 5
      discord/utils.py
  18. 2
      discord/widget.py

3
discord/audit_logs.py

@ -51,8 +51,7 @@ def _transform_snowflake(entry, data):
def _transform_channel(entry, data): def _transform_channel(entry, data):
if data is None: if data is None:
return None return None
channel = entry.guild.get_channel(int(data)) or Object(id=data) return entry.guild.get_channel(int(data)) or Object(id=data)
return channel
def _transform_owner_id(entry, data): def _transform_owner_id(entry, data):
if data is None: if data is None:

14
discord/client.py

@ -754,9 +754,7 @@ class Client:
@allowed_mentions.setter @allowed_mentions.setter
def allowed_mentions(self, value): def allowed_mentions(self, value):
if value is None: if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value
elif isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value self._connection.allowed_mentions = value
else: else:
raise TypeError('allowed_mentions must be AllowedMentions not {0.__class__!r}'.format(value)) raise TypeError('allowed_mentions must be AllowedMentions not {0.__class__!r}'.format(value))
@ -1227,15 +1225,13 @@ class Client:
if icon is not None: if icon is not None:
icon = utils._bytes_to_base64_data(icon) icon = utils._bytes_to_base64_data(icon)
if region is None: region = region or VoiceRegion.us_west
region = VoiceRegion.us_west.value region_value = region.value
else:
region = region.value
if code: if code:
data = await self.http.create_from_template(code, name, region, icon) data = await self.http.create_from_template(code, name, region_value, icon)
else: else:
data = await self.http.create_guild(name, region, icon) data = await self.http.create_guild(name, region_value, icon)
return Guild(data=data, state=self._connection) return Guild(data=data, state=self._connection)
# Invite management # Invite management

2
discord/errors.py

@ -104,7 +104,7 @@ class HTTPException(DiscordException):
fmt = '{0.status} {0.reason} (error code: {1})' fmt = '{0.status} {0.reason} (error code: {1})'
if len(self.text): if len(self.text):
fmt = fmt + ': {2}' fmt += ': {2}'
super().__init__(fmt.format(self.response, self.code, self.text)) super().__init__(fmt.format(self.response, self.code, self.text))

7
discord/ext/commands/bot.py

@ -165,9 +165,8 @@ class BotBase(GroupMixin):
return return
cog = context.cog cog = context.cog
if cog: if cog and Cog._get_overridden_method(cog.cog_command_error) is not None:
if Cog._get_overridden_method(cog.cog_command_error) is not None: return
return
print('Ignoring exception in command {}:'.format(context.command), file=sys.stderr) print('Ignoring exception in command {}:'.format(context.command), file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
@ -770,7 +769,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__) self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name) self._call_module_finalizers(lib, name)
self.load_extension(name) self.load_extension(name)
except Exception as e: except Exception:
# if the load failed, the remnants should have been # if the load failed, the remnants should have been
# cleaned from the load_extension function call # cleaned from the load_extension function call
# so let's load it from our old compiled library. # so let's load it from our old compiled library.

6
discord/ext/commands/cog.py

@ -96,7 +96,7 @@ class CogMeta(type):
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
name, bases, attrs = args name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name) attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = command_attrs = kwargs.pop('command_attrs', {}) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
description = kwargs.pop('description', None) description = kwargs.pop('description', None)
if description is None: if description is None:
@ -126,7 +126,7 @@ class CogMeta(type):
commands[elem] = value commands[elem] = value
elif inspect.iscoroutinefunction(value): elif inspect.iscoroutinefunction(value):
try: try:
is_listener = getattr(value, '__cog_listener__') getattr(value, '__cog_listener__')
except AttributeError: except AttributeError:
continue continue
else: else:
@ -192,7 +192,7 @@ class Cog(metaclass=CogMeta):
parent = lookup[parent.qualified_name] parent = lookup[parent.qualified_name]
# Update our parent's reference to our self # Update our parent's reference to our self
removed = parent.remove_command(command.name) parent.remove_command(command.name)
parent.add_command(command) parent.add_command(command)
return self return self

2
discord/ext/commands/context.py

@ -313,7 +313,7 @@ class Context(discord.abc.Messageable):
entity = bot.get_cog(entity) or bot.get_command(entity) entity = bot.get_cog(entity) or bot.get_command(entity)
try: try:
qualified_name = entity.qualified_name entity.qualified_name
except AttributeError: except AttributeError:
# if we're here then it's not a cog, group, or command. # if we're here then it's not a cog, group, or command.
return None return None

5
discord/ext/commands/core.py

@ -715,9 +715,8 @@ class Command(_BaseCommand):
except RuntimeError: except RuntimeError:
break break
if not self.ignore_extra: if not self.ignore_extra and not view.eof:
if not view.eof: raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
async def call_before_hooks(self, ctx): async def call_before_hooks(self, ctx):
# now that we're done preparing we can call the pre-command hooks # now that we're done preparing we can call the pre-command hooks

2
discord/ext/tasks/__init__.py

@ -103,7 +103,7 @@ class Loop:
now = datetime.datetime.now(datetime.timezone.utc) now = datetime.datetime.now(datetime.timezone.utc)
if now > self._next_iteration: if now > self._next_iteration:
self._next_iteration = now self._next_iteration = now
except self._valid_exception as exc: except self._valid_exception:
self._last_iteration_failed = True self._last_iteration_failed = True
if not self.reconnect: if not self.reconnect:
raise raise

13
discord/gateway.py

@ -422,16 +422,11 @@ class DiscordWebSocket:
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) self._buffer.extend(msg)
if len(msg) >= 4: if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
if msg[-4:] == b'\x00\x00\xff\xff':
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
else:
return
else:
return return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
msg = json.loads(msg) msg = json.loads(msg)
log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg) log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)

28
discord/iterators.py

@ -291,13 +291,10 @@ class HistoryIterator(_AsyncIterator):
def _get_retrieve(self): def _get_retrieve(self):
l = self.limit l = self.limit
if l is None: if l is None or l > 100:
r = 100 r = 100
elif l <= 100:
r = l
else: else:
r = 100 r = l
self.retrieve = r self.retrieve = r
return r > 0 return r > 0
@ -447,13 +444,10 @@ class AuditLogIterator(_AsyncIterator):
def _get_retrieve(self): def _get_retrieve(self):
l = self.limit l = self.limit
if l is None: if l is None or l > 100:
r = 100 r = 100
elif l <= 100:
r = l
else: else:
r = 100 r = l
self.retrieve = r self.retrieve = r
return r > 0 return r > 0
@ -547,13 +541,10 @@ class GuildIterator(_AsyncIterator):
def _get_retrieve(self): def _get_retrieve(self):
l = self.limit l = self.limit
if l is None: if l is None or l > 100:
r = 100 r = 100
elif l <= 100:
r = l
else: else:
r = 100 r = l
self.retrieve = r self.retrieve = r
return r > 0 return r > 0
@ -636,13 +627,10 @@ class MemberIterator(_AsyncIterator):
def _get_retrieve(self): def _get_retrieve(self):
l = self.limit l = self.limit
if l is None: if l is None or l > 1000:
r = 1000 r = 1000
elif l <= 1000:
r = l
else: else:
r = 1000 r = l
self.retrieve = r self.retrieve = r
return r > 0 return r > 0

10
discord/member.py

@ -398,7 +398,7 @@ class Member(discord.abc.Messageable, _BaseUser):
if they have a guild specific nickname then that if they have a guild specific nickname then that
is returned instead. is returned instead.
""" """
return self.nick if self.nick is not None else self.name return self.nick or self.name
@property @property
def activity(self): def activity(self):
@ -431,11 +431,7 @@ class Member(discord.abc.Messageable, _BaseUser):
if self._user.mentioned_in(message): if self._user.mentioned_in(message):
return True return True
for role in message.role_mentions: return any(self._roles.has(role.id) for role in message.role_mentions)
if self._roles.has(role.id):
return True
return False
def permissions_in(self, channel): def permissions_in(self, channel):
"""An alias for :meth:`abc.GuildChannel.permissions_for`. """An alias for :meth:`abc.GuildChannel.permissions_for`.
@ -582,7 +578,7 @@ class Member(discord.abc.Messageable, _BaseUser):
# nick not present so... # nick not present so...
pass pass
else: else:
nick = nick if nick else '' nick = nick or ''
if self._state.self_id == self.id: if self._state.self_id == self.id:
await http.change_my_nickname(guild_id, nick, reason=reason) await http.change_my_nickname(guild_id, nick, reason=reason)
else: else:

13
discord/opus.py

@ -276,17 +276,14 @@ class _OpusStruct:
@staticmethod @staticmethod
def get_opus_version() -> str: def get_opus_version() -> str:
if not is_loaded(): if not is_loaded() and not _load_default():
if not _load_default(): raise OpusNotLoaded()
raise OpusNotLoaded()
return _lib.opus_get_version_string().decode('utf-8') return _lib.opus_get_version_string().decode('utf-8')
class Encoder(_OpusStruct): class Encoder(_OpusStruct):
def __init__(self, application=APPLICATION_AUDIO): def __init__(self, application=APPLICATION_AUDIO):
if not is_loaded(): _OpusStruct.get_opus_version()
if not _load_default():
raise OpusNotLoaded()
self.application = application self.application = application
self._state = self._create_state() self._state = self._create_state()
@ -342,9 +339,7 @@ class Encoder(_OpusStruct):
class Decoder(_OpusStruct): class Decoder(_OpusStruct):
def __init__(self): def __init__(self):
if not is_loaded(): _OpusStruct.get_opus_version()
if not _load_default():
raise OpusNotLoaded()
self._state = self._create_state() self._state = self._create_state()

2
discord/shard.py

@ -413,7 +413,7 @@ class AutoShardedClient(Client):
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) shard_ids = self.shard_ids or range(self.shard_count)
self._connection.shard_ids = shard_ids self._connection.shard_ids = shard_ids
for shard_id in shard_ids: for shard_id in shard_ids:

9
discord/state.py

@ -896,7 +896,7 @@ class ConnectionState:
log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id'])
return return
if data.get('unavailable', False) and guild is not None: if data.get('unavailable', False):
# GUILD_DELETE with unavailable being True means that the # GUILD_DELETE with unavailable being True means that the
# guild that was available is now currently unavailable # guild that was available is now currently unavailable
guild.unavailable = True guild.unavailable = True
@ -928,10 +928,9 @@ class ConnectionState:
def parse_guild_ban_remove(self, data): def parse_guild_ban_remove(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
if guild is not None: if guild is not None and 'user' in data:
if 'user' in data: user = self.store_user(data['user'])
user = self.store_user(data['user']) self.dispatch('member_unban', guild, user)
self.dispatch('member_unban', guild, user)
def parse_guild_role_create(self, data): def parse_guild_role_create(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))

8
discord/template.py

@ -168,12 +168,10 @@ class Template:
if icon is not None: if icon is not None:
icon = _bytes_to_base64_data(icon) icon = _bytes_to_base64_data(icon)
if region is None: region = region or VoiceRegion.us_west
region = VoiceRegion.us_west.value region_value = region.value
else:
region = region.value
data = await self._state.http.create_from_template(self.code, name, region, icon) data = await self._state.http.create_from_template(self.code, name, region_value, icon)
return Guild(data=data, state=self._state) return Guild(data=data, state=self._state)
async def sync(self): async def sync(self):

6
discord/user.py

@ -274,11 +274,7 @@ class BaseUser(_BaseUser):
if message.mention_everyone: if message.mention_everyone:
return True return True
for user in message.mentions: return any(user.id == self.id for user in message.mentions)
if user.id == self.id:
return True
return False
class ClientUser(BaseUser): class ClientUser(BaseUser):
"""Represents your Discord user. """Represents your Discord user.

5
discord/utils.py

@ -419,11 +419,8 @@ def _string_width(string, *, _IS_ASCII=_IS_ASCII):
return match.endpos return match.endpos
UNICODE_WIDE_CHAR_TYPE = 'WFA' UNICODE_WIDE_CHAR_TYPE = 'WFA'
width = 0
func = unicodedata.east_asian_width func = unicodedata.east_asian_width
for char in string: return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
width += 2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1
return width
def resolve_invite(invite): def resolve_invite(invite):
""" """

2
discord/widget.py

@ -156,7 +156,7 @@ class WidgetMember(BaseUser):
@property @property
def display_name(self): def display_name(self):
""":class:`str`: Returns the member's display name.""" """:class:`str`: Returns the member's display name."""
return self.nick if self.nick else self.name return self.nick or self.name
class Widget: class Widget:
"""Represents a :class:`Guild` widget. """Represents a :class:`Guild` widget.

Loading…
Cancel
Save