diff --git a/discord/channel.py b/discord/channel.py index 95aaaea80..91bf878c9 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -26,7 +26,7 @@ DEALINGS IN THE SOFTWARE. import copy from . import utils from .permissions import Permissions, PermissionOverwrite -from .enums import ChannelType +from .enums import ChannelType, try_enum from collections import namedtuple from .mixins import Hashable from .role import Role @@ -54,68 +54,63 @@ class Channel(Hashable): Attributes ----------- - name : str + name: str The channel name. - server : :class:`Server` + server: :class:`Server` The server the channel belongs to. - id : str + id: str The channel ID. - topic : Optional[str] + topic: Optional[str] The channel's topic. None if it doesn't exist. - is_private : bool + is_private: bool ``True`` if the channel is a private channel (i.e. PM). ``False`` in this case. - position : int + position: int The position in the channel list. This is a number that starts at 0. e.g. the top channel is position 0. The position varies depending on being a voice channel or a text channel, so a 0 position voice channel is on top of the voice channel list. - type : :class:`ChannelType` + type: :class:`ChannelType` The channel type. There is a chance that the type will be ``str`` if the channel type is not within the ones recognised by the enumerator. - bitrate : int + bitrate: int The channel's preferred audio bitrate in bits per second. voice_members A list of :class:`Members` that are currently inside this voice channel. If :attr:`type` is not :attr:`ChannelType.voice` then this is always an empty array. - user_limit : int + user_limit: int The channel's limit for number of members that can be in a voice channel. """ - __slots__ = [ 'voice_members', 'name', 'id', 'server', 'topic', 'position', - 'is_private', 'type', 'bitrate', 'user_limit', - '_permission_overwrites' ] + __slots__ = ( 'voice_members', 'name', 'id', 'server', 'topic', + 'type', 'bitrate', 'user_limit', '_state', 'position', + '_permission_overwrites' ) - def __init__(self, **kwargs): - self._update(**kwargs) + def __init__(self, *, state, server, data): + self._state = state + self._update(server, data) self.voice_members = [] def __str__(self): return self.name - def _update(self, **kwargs): - self.name = kwargs.get('name') - self.server = kwargs.get('server') - self.id = kwargs.get('id') - self.topic = kwargs.get('topic') - self.is_private = False - self.position = kwargs.get('position') - self.bitrate = kwargs.get('bitrate') - self.type = kwargs.get('type') - self.user_limit = kwargs.get('user_limit') - try: - self.type = ChannelType(self.type) - except: - pass - + def _update(self, server, data): + self.server = server + self.name = data['name'] + self.id = data['id'] + self.topic = data.get('topic') + self.position = data['position'] + self.bitrate = data.get('bitrate') + self.type = data['type'] + self.user_limit = data.get('user_limit') self._permission_overwrites = [] everyone_index = 0 everyone_id = self.server.id - for index, overridden in enumerate(kwargs.get('permission_overwrites', [])): + for index, overridden in enumerate(data.get('permission_overwrites', [])): overridden_id = overridden['id'] self._permission_overwrites.append(Overwrites(**overridden)) - if overridden.get('type') == 'member': + if overridden['type'] == 'member': continue if overridden_id == everyone_id: @@ -151,6 +146,10 @@ class Channel(Hashable): """bool : Indicates if this is the default channel for the :class:`Server` it belongs to.""" return self.server.id == self.id + @property + def is_private(self): + return False + @property def mention(self): """str : The string that allows you to mention the channel.""" @@ -354,19 +353,20 @@ class PrivateChannel(Hashable): :attr:`ChannelType.group` then this is always ``None``. """ - __slots__ = ['id', 'recipients', 'type', 'owner', 'icon', 'name', 'me'] + __slots__ = ['id', 'recipients', 'type', 'owner', 'icon', 'name', 'me', '_state'] - def __init__(self, me, **kwargs): - self.recipients = [User(**u) for u in kwargs['recipients']] - self.id = kwargs['id'] + def __init__(self, *, me, state, data): + self._state = state + self.recipients = [state.try_insert_user(u) for u in data['recipients']] + self.id = data['id'] self.me = me - self.type = ChannelType(kwargs['type']) - self._update_group(**kwargs) + self.type = ChannelType(data['type']) + self._update_group(data) - def _update_group(self, **kwargs): - owner_id = kwargs.get('owner_id') - self.icon = kwargs.get('icon') - self.name = kwargs.get('name') + def _update_group(self, data): + owner_id = data.get('owner_id') + self.icon = data.get('icon') + self.name = data.get('name') self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) @property diff --git a/discord/client.py b/discord/client.py index f8cd4c324..b1dd1c222 100644 --- a/discord/client.py +++ b/discord/client.py @@ -32,7 +32,6 @@ from .server import Server from .message import Message from .invite import Invite from .object import Object -from .reaction import Reaction from .role import Role from .errors import * from .state import ConnectionState @@ -145,16 +144,15 @@ class Client: self.shard_id = options.get('shard_id') self.shard_count = options.get('shard_count') - max_messages = options.get('max_messages') - if max_messages is None or max_messages < 100: - max_messages = 5000 - - self.connection = ConnectionState(self.dispatch, self.request_offline_members, - self._syncer, max_messages, loop=self.loop) - connector = options.pop('connector', None) self.http = HTTPClient(connector, loop=self.loop) + self.connection = ConnectionState(dispatch=self.dispatch, + chunker=self.request_offline_members, + syncer=self._syncer, + http=self.http, loop=self.loop, + **options) + self._closed = asyncio.Event(loop=self.loop) self._is_logged_in = asyncio.Event(loop=self.loop) self._is_ready = asyncio.Event(loop=self.loop) @@ -914,7 +912,7 @@ class Client: raise InvalidArgument('user argument must be a User') data = yield from self.http.start_private_message(user.id) - channel = PrivateChannel(me=self.user, **data) + channel = PrivateChannel(me=self.user, data=data, state=self.connection.ctx) self.connection._add_private_channel(channel) return channel @@ -1151,7 +1149,7 @@ class Client: data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts, embed=embed) channel = self.get_channel(data.get('channel_id')) - message = self.connection._create_message(channel=channel, **data) + message = Message(channel=channel, state=self.connection.ctx, data=data) return message @asyncio.coroutine @@ -1233,7 +1231,7 @@ class Client: data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id, filename=filename, content=content, tts=tts) channel = self.get_channel(data.get('channel_id')) - message = self.connection._create_message(channel=channel, **data) + message = Message(channel=channel, state=self.connection.ctx, data=data) return message @asyncio.coroutine @@ -1438,7 +1436,7 @@ class Client: embed = embed.to_dict() if embed else None guild_id = channel.server.id if not getattr(channel, 'is_private', True) else None data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=guild_id, embed=embed) - return self.connection._create_message(channel=channel, **data) + return Message(channel=channel, state=self.connection.ctx, data=data) @asyncio.coroutine def get_message(self, channel, id): @@ -1471,7 +1469,7 @@ class Client: """ data = yield from self.http.get_message(channel.id, id) - return self.connection._create_message(channel=channel, **data) + return Message(channel=channel, state=self.connection.ctx, data=data) @asyncio.coroutine def pin_message(self, message): @@ -1541,7 +1539,7 @@ class Client: """ data = yield from self.http.pins_from(channel.id) - return [self.connection._create_message(channel=channel, **m) for m in data] + return [Message(channel=channel, state=self.connection.ctx, data=m) for m in data] def _logs_from(self, channel, limit=100, before=None, after=None, around=None): """|coro| @@ -1622,7 +1620,7 @@ class Client: def generator(data): for message in data: - yield self.connection._create_message(channel=channel, **message) + yield Message(channel=channel, state=self.connection.ctx, data=message) result = [] while limit > 0: @@ -2161,7 +2159,7 @@ class Client: perms.append(payload) data = yield from self.http.create_channel(server.id, name, str(type), permission_overwrites=perms) - channel = Channel(server=server, **data) + channel = Channel(server=server, state=self.connection.ctx, data=data) return channel @asyncio.coroutine @@ -2275,7 +2273,7 @@ class Client: region = region.name data = yield from self.http.create_server(name, region, icon) - return Server(**data) + return Server(data=data, state=self.connection.ctx) @asyncio.coroutine def edit_server(self, server, **fields): @@ -2397,7 +2395,7 @@ class Client: """ data = yield from self.http.get_bans(server.id) - return [User(**user['user']) for user in data] + return [self.connection.try_insert_user(user) for user in data] @asyncio.coroutine def prune_members(self, server, *, days): @@ -2514,7 +2512,7 @@ class Client: img = utils._bytes_to_base64_data(image) data = yield from self.http.create_custom_emoji(server.id, name, img) - return Emoji(server=server, **data) + return Emoji(server=server, data=data, state=self.connection.ctx) @asyncio.coroutine def delete_custom_emoji(self, emoji): @@ -2989,7 +2987,7 @@ class Client: """ data = yield from self.http.create_role(server.id) - role = Role(server=server, **data) + role = Role(server=server, data=data, state=self.connection.ctx) # we have to call edit because you can't pass a payload to the # http request currently. @@ -3271,7 +3269,7 @@ class Client: data = yield from self.http.application_info() return AppInfo(id=data['id'], name=data['name'], description=data['description'], icon=data['icon'], - owner=User(**data['owner'])) + owner=User(state=self.connection.ctx, data=data['owner'])) @asyncio.coroutine def get_user_info(self, user_id): @@ -3300,4 +3298,4 @@ class Client: Fetching the user failed. """ data = yield from self.http.get_user_info(user_id) - return User(**data) + return User(state=self.connection.ctx, data=data) diff --git a/discord/emoji.py b/discord/emoji.py index 82384aa5d..81d5dabac 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -68,11 +68,12 @@ class Emoji(Hashable): A list of :class:`Role` that is allowed to use this emoji. If roles is empty, the emoji is unrestricted. """ - __slots__ = ["require_colons", "managed", "id", "name", "roles", 'server'] + __slots__ = ('require_colons', 'managed', 'id', 'name', 'roles', 'server', '_state') - def __init__(self, **kwargs): - self.server = kwargs.pop('server') - self._from_data(kwargs) + def __init__(self, *, server, state, data): + self.server = server + self._state = state + self._from_data(data) def _from_data(self, emoji): self.require_colons = emoji.get('require_colons') @@ -86,9 +87,10 @@ class Emoji(Hashable): def _iterator(self): for attr in self.__slots__: - value = getattr(self, attr, None) - if value is not None: - yield (attr, value) + if attr[0] != '_': + value = getattr(self, attr, None) + if value is not None: + yield (attr, value) def __iter__(self): return self._iterator() diff --git a/discord/invite.py b/discord/invite.py index b0fedccf6..4e19b2d6d 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -76,23 +76,24 @@ class Invite(Hashable): """ - __slots__ = [ 'max_age', 'code', 'server', 'revoked', 'created_at', 'uses', - 'temporary', 'max_uses', 'xkcd', 'inviter', 'channel' ] - - def __init__(self, **kwargs): - self.max_age = kwargs.get('max_age') - self.code = kwargs.get('code') - self.server = kwargs.get('server') - self.revoked = kwargs.get('revoked') - self.created_at = parse_time(kwargs.get('created_at')) - self.temporary = kwargs.get('temporary') - self.uses = kwargs.get('uses') - self.max_uses = kwargs.get('max_uses') - self.xkcd = kwargs.get('xkcdpass') - - inviter_data = kwargs.get('inviter') - self.inviter = None if inviter_data is None else User(**inviter_data) - self.channel = kwargs.get('channel') + __slots__ = ( 'max_age', 'code', 'server', 'revoked', 'created_at', 'uses', + 'temporary', 'max_uses', 'xkcd', 'inviter', 'channel', '_state' ) + + def __init__(self, *, state, data): + self._state = state + self.max_age = data.get('max_age') + self.code = data.get('code') + self.server = data.get('server') + self.revoked = data.get('revoked') + self.created_at = parse_time(data.get('created_at')) + self.temporary = data.get('temporary') + self.uses = data.get('uses') + self.max_uses = data.get('max_uses') + self.xkcd = data.get('xkcdpass') + + inviter_data = data.get('inviter') + self.inviter = None if inviter_data is None else User(state=state, data=data) + self.channel = data.get('channel') def __str__(self): return self.url diff --git a/discord/iterators.py b/discord/iterators.py index 2ea514367..63a8776d3 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -72,7 +72,6 @@ class LogsFromIterator: def __init__(self, client, channel, limit, before=None, after=None, around=None, reverse=False): self.client = client - self.connection = client.connection self.channel = channel self.limit = limit self.before = before @@ -81,6 +80,7 @@ class LogsFromIterator: self.reverse = reverse self._filter = None # message dict -> bool self.messages = asyncio.Queue() + self.ctx = client.connection.ctx if self.around: if self.limit > 101: @@ -92,18 +92,18 @@ class LogsFromIterator: self._retrieve_messages = self._retrieve_messages_around_strategy if self.before and self.after: - self._filter = lambda m: int(self.after.id) < int(m['id']) < int(self.before.id) + self._filter = lambda m: self.after.id < m['id'] < self.before.id elif self.before: - self._filter = lambda m: int(m['id']) < int(self.before.id) + self._filter = lambda m: m['id'] < self.before.id elif self.after: - self._filter = lambda m: int(self.after.id) < int(m['id']) + self._filter = lambda m: self.after.id < m['id'] elif self.before and self.after: if self.reverse: self._retrieve_messages = self._retrieve_messages_after_strategy - self._filter = lambda m: int(m['id']) < int(self.before.id) + self._filter = lambda m: m['id'] < self.before.id else: self._retrieve_messages = self._retrieve_messages_before_strategy - self._filter = lambda m: int(m['id']) > int(self.after.id) + self._filter = lambda m: m['id'] > self.after.id elif self.after: self._retrieve_messages = self._retrieve_messages_after_strategy else: @@ -126,9 +126,7 @@ class LogsFromIterator: if self._filter: data = filter(self._filter, data) for element in data: - yield from self.messages.put( - self.connection._create_message( - channel=self.channel, **element)) + yield from self.messages.put(Message(channel=self.channel, state=self.ctx, data=element)) @asyncio.coroutine def _retrieve_messages(self, retrieve): @@ -141,7 +139,7 @@ class LogsFromIterator: data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) if len(data): self.limit -= retrieve - self.before = Object(id=data[-1]['id']) + self.before = Object(id=int(data[-1]['id'])) return data @asyncio.coroutine @@ -150,7 +148,7 @@ class LogsFromIterator: data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) if len(data): self.limit -= retrieve - self.after = Object(id=data[0]['id']) + self.after = Object(id=int(data[0]['id'])) return data @asyncio.coroutine diff --git a/discord/member.py b/discord/member.py index 50ad184a7..4d0f86b98 100644 --- a/discord/member.py +++ b/discord/member.py @@ -28,9 +28,11 @@ from .user import User from .game import Game from .permissions import Permissions from . import utils -from .enums import Status, ChannelType +from .enums import Status, ChannelType, try_enum from .colour import Colour + import copy +import inspect class VoiceState: """Represents a Discord user's voice state. @@ -52,8 +54,8 @@ class VoiceState: is not currently in a voice channel. """ - __slots__ = [ 'session_id', 'deaf', 'mute', 'self_mute', - 'self_deaf', 'is_afk', 'voice_channel' ] + __slots__ = ( 'session_id', 'deaf', 'mute', 'self_mute', + 'self_deaf', 'is_afk', 'voice_channel' ) def __init__(self, **kwargs): self.session_id = kwargs.get('session_id') @@ -74,12 +76,57 @@ def flatten_voice_states(cls): setattr(cls, attr, property(getter)) return cls +def flatten_user(cls): + for attr, value in User.__dict__.items(): + # ignore private/special methods + if attr.startswith('_'): + continue + + # don't override what we already have + if attr in cls.__dict__: + continue + + # if it's a slotted attribute or a property, redirect it + # slotted members are implemented as member_descriptors in Type.__dict__ + if hasattr(value, '__get__'): + def getter(self, x=attr): + return getattr(self._user, x) + setattr(cls, attr, property(getter, doc='Equivalent to :attr:`User.%s`' % attr)) + else: + # probably a member function by now + def generate_function(x): + def general(self, *args, **kwargs): + return getattr(self._user, x)(*args, **kwargs) + + general.__name__ = x + return general + + func = generate_function(attr) + func.__doc__ = value.__doc__ + setattr(cls, attr, func) + + return cls + @flatten_voice_states -class Member(User): +@flatten_user +class Member: """Represents a Discord member to a :class:`Server`. - This is a subclass of :class:`User` that extends more functionality - that server members have such as roles and permissions. + This implements a lot of the functionality of :class:`User`. + + Supported Operations: + + +-----------+-----------------------------------------------+ + | Operation | Description | + +===========+===============================================+ + | x == y | Checks if two members are equal. | + +-----------+-----------------------------------------------+ + | x != y | Checks if two members are not equal. | + +-----------+-----------------------------------------------+ + | hash(x) | Return the member's hash. | + +-----------+-----------------------------------------------+ + | str(x) | Returns the member's name with discriminator. | + +-----------+-----------------------------------------------+ Attributes ---------- @@ -103,18 +150,31 @@ class Member(User): The server specific nickname of the user. """ - __slots__ = [ 'roles', 'joined_at', 'status', 'game', 'server', 'nick', 'voice' ] + __slots__ = ('roles', 'joined_at', 'status', 'game', 'server', 'nick', 'voice', '_user', '_state') - def __init__(self, **kwargs): - super().__init__(**kwargs.get('user')) - self.voice = VoiceState(**kwargs) - self.joined_at = utils.parse_time(kwargs.get('joined_at')) - self.roles = kwargs.get('roles', []) + def __init__(self, *, data, server, state): + self._state = state + self._user = state.try_insert_user(data['user']) + self.voice = VoiceState(**data) + self.joined_at = utils.parse_time(data.get('joined_at')) + self.roles = data.get('roles', []) self.status = Status.offline - game = kwargs.get('game', {}) + game = data.get('game', {}) self.game = Game(**game) if game else None - self.server = kwargs.get('server', None) - self.nick = kwargs.get('nick', None) + self.server = server + self.nick = data.get('nick', None) + + def __str__(self): + return self._user.__str__() + + def __eq__(self, other): + return isinstance(other, Member) and other._user.id == self._user.id and self.server.id == other.server.id + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self._user.id) def _update_voice_state(self, **kwargs): self.voice.self_mute = kwargs.get('self_mute', False) @@ -146,6 +206,35 @@ class Member(User): ret.voice = copy.copy(self.voice) return ret + def _update(self, data, user): + self._user.name = user['username'] + self._user.discriminator = user['discriminator'] + self._user.avatar = user['avatar'] + self._user.bot = user.get('bot', False) + + # the nickname change is optional, + # if it isn't in the payload then it didn't change + if 'nick' in data: + self.nick = data['nick'] + + # update the roles + self.roles = [self.server.default_role] + for role in self.server.roles: + if role.id in data['roles']: + self.roles.append(role) + + # sort the roles by ID since they can be "randomised" + self.roles.sort(key=lambda r: int(r.id)) + + def _presence_update(self, data, user): + self.status = try_enum(Status, data['status']) + game = data.get('game', {}) + self.game = Game(**game) if game else None + u = self._user + u.name = user.get('username', u.name) + u.avatar = user.get('avatar', u.avatar) + u.discriminator = user.get('discriminator', u.discriminator) + @property def colour(self): """A property that returns a :class:`Colour` denoting the rendered colour @@ -173,13 +262,20 @@ class Member(User): @property def mention(self): + """Returns a string that mentions the member.""" if self.nick: return '<@!{}>'.format(self.id) return '<@{}>'.format(self.id) def mentioned_in(self, message): - mentioned = super().mentioned_in(message) - if mentioned: + """Checks if the member is mentioned in the specified message. + + Parameters + ----------- + message: :class:`Message` + The message to check if you're mentioned in. + """ + if self._user.mentioned_in(message): return True for role in message.role_mentions: diff --git a/discord/message.py b/discord/message.py index e6e2fdd17..c50f8ab5a 100644 --- a/discord/message.py +++ b/discord/message.py @@ -107,43 +107,44 @@ class Message: Reactions to a message. Reactions can be either custom emoji or standard unicode emoji. """ - __slots__ = [ 'edited_timestamp', 'timestamp', 'tts', 'content', 'channel', + __slots__ = ( 'edited_timestamp', 'timestamp', 'tts', 'content', 'channel', 'mention_everyone', 'embeds', 'id', 'mentions', 'author', - 'channel_mentions', 'server', '_raw_mentions', 'attachments', - '_clean_content', '_raw_channel_mentions', 'nonce', 'pinned', - 'role_mentions', '_raw_role_mentions', 'type', 'call', - '_system_content', 'reactions' ] + 'channel_mentions', 'server', '_cs_raw_mentions', 'attachments', + '_cs_clean_content', '_cs_raw_channel_mentions', 'nonce', 'pinned', + 'role_mentions', '_cs_raw_role_mentions', 'type', 'call', + '_cs_system_content', '_state', 'reactions' ) - def __init__(self, **kwargs): + def __init__(self, *, state, channel, data): + self._state = state self.reactions = kwargs.pop('reactions') for reaction in self.reactions: reaction.message = self - self._update(**kwargs) + self._update(channel, data) - def _update(self, **data): + def _update(self, channel, data): # at the moment, the timestamps seem to be naive so they have no time zone and operate on UTC time. # we can use this to our advantage to use strptime instead of a complicated parsing routine. # example timestamp: 2015-08-21T12:03:45.782000+00:00 # sometimes the .%f modifier is missing - self.edited_timestamp = utils.parse_time(data.get('edited_timestamp')) - self.timestamp = utils.parse_time(data.get('timestamp')) + self.edited_timestamp = utils.parse_time(data['edited_timestamp']) + self.timestamp = utils.parse_time(data['timestamp']) self.tts = data.get('tts', False) self.pinned = data.get('pinned', False) - self.content = data.get('content') - self.mention_everyone = data.get('mention_everyone') - self.embeds = data.get('embeds') - self.id = data.get('id') - self.channel = data.get('channel') - self.author = User(**data.get('author', {})) + self.content = data['content'] + self.mention_everyone = data['mention_everyone'] + self.embeds = data['embeds'] + self.id = data['id'] + self.channel = channel + self.author = self._state.try_insert_user(data['author']) self.nonce = data.get('nonce') - self.attachments = data.get('attachments') + self.attachments = data['attachments'] self.type = try_enum(MessageType, data.get('type')) - self._handle_upgrades(data.get('channel_id')) + self._handle_upgrades(data['channel_id']) self._handle_mentions(data.get('mentions', []), data.get('mention_roles', [])) self._handle_call(data.get('call')) # clear the cached properties - cached = filter(lambda attr: attr[0] == '_', self.__slots__) + cached = filter(lambda attr: attr.startswith('_cs_'), self.__slots__) for attr in cached: try: delattr(self, attr) @@ -155,7 +156,7 @@ class Message: self.channel_mentions = [] self.role_mentions = [] if getattr(self.channel, 'is_private', True): - self.mentions = [User(**m) for m in mentions] + self.mentions = [self._state.try_insert_user(m) for m in mentions] return if self.server is not None: @@ -193,7 +194,7 @@ class Message: call['participants'] = participants self.call = CallMessage(message=self, **call) - @utils.cached_slot_property('_raw_mentions') + @utils.cached_slot_property('_cs_raw_mentions') def raw_mentions(self): """A property that returns an array of user IDs matched with the syntax of <@user_id> in the message content. @@ -203,21 +204,21 @@ class Message: """ return re.findall(r'<@!?([0-9]+)>', self.content) - @utils.cached_slot_property('_raw_channel_mentions') + @utils.cached_slot_property('_cs_raw_channel_mentions') def raw_channel_mentions(self): """A property that returns an array of channel IDs matched with the syntax of <#channel_id> in the message content. """ return re.findall(r'<#([0-9]+)>', self.content) - @utils.cached_slot_property('_raw_role_mentions') + @utils.cached_slot_property('_cs_raw_role_mentions') def raw_role_mentions(self): """A property that returns an array of role IDs matched with the syntax of <@&role_id> in the message content. """ return re.findall(r'<@&([0-9]+)>', self.content) - @utils.cached_slot_property('_clean_content') + @utils.cached_slot_property('_cs_clean_content') def clean_content(self): """A property that returns the content in a "cleaned up" manner. This basically means that mentions are transformed @@ -288,7 +289,7 @@ class Message: if found is not None: self.author = found - @utils.cached_slot_property('_system_content') + @utils.cached_slot_property('_cs_system_content') def system_content(self): """A property that returns the content that is rendered regardless of the :attr:`Message.type`. diff --git a/discord/role.py b/discord/role.py index c375c228d..eb111d339 100644 --- a/discord/role.py +++ b/discord/role.py @@ -78,12 +78,13 @@ class Role(Hashable): Indicates if the role can be mentioned by users. """ - __slots__ = ['id', 'name', 'permissions', 'color', 'colour', 'position', - 'managed', 'mentionable', 'hoist', 'server' ] + __slots__ = ('id', 'name', 'permissions', 'color', 'colour', 'position', + 'managed', 'mentionable', 'hoist', 'server', '_state' ) - def __init__(self, **kwargs): - self.server = kwargs.pop('server') - self._update(**kwargs) + def __init__(self, *, server, state, data): + self.server = server + self._state = state + self._update(data) def __str__(self): return self.name @@ -118,15 +119,15 @@ class Role(Hashable): return NotImplemented return not r - def _update(self, **kwargs): - self.id = kwargs.get('id') - self.name = kwargs.get('name') - self.permissions = Permissions(kwargs.get('permissions', 0)) - self.position = kwargs.get('position', 0) - self.colour = Colour(kwargs.get('color', 0)) - self.hoist = kwargs.get('hoist', False) - self.managed = kwargs.get('managed', False) - self.mentionable = kwargs.get('mentionable', False) + def _update(self, data): + self.id = data['id'] + self.name = data['name'] + self.permissions = Permissions(data.get('permissions', 0)) + self.position = data.get('position', 0) + self.colour = Colour(data.get('color', 0)) + self.hoist = data.get('hoist', False) + self.managed = data.get('managed', False) + self.mentionable = data.get('mentionable', False) self.color = self.colour @property diff --git a/discord/server.py b/discord/server.py index 18e9647ee..414d32304 100644 --- a/discord/server.py +++ b/discord/server.py @@ -52,39 +52,39 @@ class Server(Hashable): Attributes ---------- - name : str + name: str The server name. - me : :class:`Member` + me: :class:`Member` Similar to :attr:`Client.user` except an instance of :class:`Member`. This is essentially used to get the member version of yourself. roles A list of :class:`Role` that the server has available. emojis A list of :class:`Emoji` that the server owns. - region : :class:`ServerRegion` + region: :class:`ServerRegion` The region the server belongs on. There is a chance that the region will be a ``str`` if the value is not recognised by the enumerator. - afk_timeout : int + afk_timeout: int The timeout to get sent to the AFK channel. - afk_channel : :class:`Channel` + afk_channel: :class:`Channel` The channel that denotes the AFK channel. None if it doesn't exist. members An iterable of :class:`Member` that are currently on the server. channels An iterable of :class:`Channel` that are currently on the server. - icon : str + icon: str The server's icon. - id : str + id: str The server's ID. - owner : :class:`Member` + owner: :class:`Member` The member who owns the server. - unavailable : bool + unavailable: bool Indicates if the server is unavailable. If this is ``True`` then the reliability of other attributes outside of :meth:`Server.id` is slim and they might all be None. It is best to not do anything with the server if it is unavailable. Check the :func:`on_server_unavailable` and :func:`on_server_available` events. - large : bool + large: bool Indicates if the server is a 'large' server. A large server is defined as having more than ``large_threshold`` count members, which for this library is set to the maximum of 250. @@ -108,17 +108,18 @@ class Server(Hashable): The server's invite splash. """ - __slots__ = ['afk_timeout', 'afk_channel', '_members', '_channels', 'icon', + __slots__ = ('afk_timeout', 'afk_channel', '_members', '_channels', 'icon', 'name', 'id', 'owner', 'unavailable', 'name', 'region', '_default_role', '_default_channel', 'roles', '_member_count', 'large', 'owner_id', 'mfa_level', 'emojis', 'features', - 'verification_level', 'splash' ] + 'verification_level', 'splash' ) - def __init__(self, **kwargs): + def __init__(self, *, data, state): self._channels = {} self.owner = None self._members = {} - self._from_data(kwargs) + self._state = state + self._from_data(data) @property def channels(self): @@ -197,9 +198,9 @@ class Server(Hashable): self.icon = guild.get('icon') self.unavailable = guild.get('unavailable', False) self.id = guild['id'] - self.roles = [Role(server=self, **r) for r in guild.get('roles', [])] + self.roles = [Role(server=self, data=r, state=self._state) for r in guild.get('roles', [])] self.mfa_level = guild.get('mfa_level') - self.emojis = [Emoji(server=self, **r) for r in guild.get('emojis', [])] + self.emojis = [Emoji(server=self, data=r, state=self._state) for r in guild.get('emojis', [])] self.features = guild.get('features', []) self.splash = guild.get('splash') @@ -211,8 +212,7 @@ class Server(Hashable): roles.append(role) mdata['roles'] = roles - member = Member(**mdata) - member.server = self + member = Member(data=mdata, server=self, state=self._state) self._add_member(member) self._sync(guild) @@ -236,18 +236,14 @@ class Server(Hashable): user_id = presence['user']['id'] member = self.get_member(user_id) if member is not None: - member.status = presence['status'] - try: - member.status = Status(member.status) - except: - pass + member.status = try_enum(Status, presence['status']) game = presence.get('game', {}) member.game = Game(**game) if game else None if 'channels' in data: channels = data['channels'] for c in channels: - channel = Channel(server=self, **c) + channel = Channel(server=self, data=c, state=self._state) self._add_channel(channel) @@ -311,7 +307,7 @@ class Server(Hashable): Parameters ----------- - name : str + name: str The name of the member to lookup with an optional discriminator. Returns diff --git a/discord/state.py b/discord/state.py index 60e9e5f58..1af26db8b 100644 --- a/discord/state.py +++ b/discord/state.py @@ -47,18 +47,20 @@ class ListenerType(enum.Enum): chunk = 0 Listener = namedtuple('Listener', ('type', 'future', 'predicate')) +StateContext = namedtuple('StateContext', 'try_insert_user http') log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'servers')) class ConnectionState: - def __init__(self, dispatch, chunker, syncer, max_messages, *, loop): + def __init__(self, *, dispatch, chunker, syncer, http, loop, **options): self.loop = loop - self.max_messages = max_messages + self.max_messages = max(options.get('max_messages', 5000), 100) self.dispatch = dispatch self.chunker = chunker self.syncer = syncer self.is_bot = None self._listeners = [] + self.ctx = StateContext(try_insert_user=self.try_insert_user, http=http) self.clear() def clear(self): @@ -66,6 +68,7 @@ class ConnectionState: self.sequence = None self.session_id = None self._calls = {} + self._users = {} self._servers = {} self._voice_clients = {} self._private_channels = {} @@ -116,6 +119,15 @@ class ConnectionState: for vc in self.voice_clients: vc.main_ws = ws + def try_insert_user(self, data): + # this way is 300% faster than `dict.setdefault`. + user_id = data['id'] + try: + return self._users[user_id] + except KeyError: + self._users[user_id] = user = User(state=self.ctx, data=data) + return user + @property def servers(self): return self._servers.values() @@ -153,7 +165,7 @@ class ConnectionState: return utils.find(lambda m: m.id == msg_id, self.messages) def _add_server_from_data(self, guild): - server = Server(**guild) + server = Server(data=guild, state=self.ctx) Server.me = property(lambda s: s.get_member(self.user.id)) Server.voice_client = property(lambda s: self._get_voice_client(s.id)) self._add_server(server) @@ -207,7 +219,7 @@ class ConnectionState: def parse_ready(self, data): self._ready_state = ReadyState(launch=asyncio.Event(), servers=[]) - self.user = User(**data['user']) + self.user = self.try_insert_user(data['user']) guilds = data.get('guilds') servers = self._ready_state.servers @@ -217,7 +229,7 @@ class ConnectionState: servers.append(server) for pm in data.get('private_channels'): - self._add_private_channel(PrivateChannel(self.user, **pm)) + self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx)) compat.create_task(self._delay_ready(), loop=self.loop) @@ -226,7 +238,7 @@ class ConnectionState: def parse_message_create(self, data): channel = self.get_channel(data.get('channel_id')) - message = self._create_message(channel=channel, **data) + message = Message(channel=channel, data=data, state=self.ctx) self.dispatch('message', message) self.messages.append(message) @@ -255,7 +267,7 @@ class ConnectionState: # embed only edit message.embeds = data['embeds'] else: - message._update(channel=message.channel, **data) + message._update(channel=message.channel, data=data) self.dispatch('message_edit', older_message, message) @@ -329,22 +341,11 @@ class ConnectionState: server._add_member(member) old_member = member._copy() - member.status = data.get('status') - try: - member.status = Status(member.status) - except: - pass - - game = data.get('game', {}) - member.game = Game(**game) if game else None - member.name = user.get('username', member.name) - member.avatar = user.get('avatar', member.avatar) - member.discriminator = user.get('discriminator', member.discriminator) - + member._presence_update(data=data, user=user) self.dispatch('member_update', old_member, member) def parse_user_update(self, data): - self.user = User(**data) + self.user = User(state=self.ctx, data=data) def parse_channel_delete(self, data): server = self._get_server(data.get('guild_id')) @@ -361,7 +362,7 @@ class ConnectionState: if channel_type is ChannelType.group: channel = self._get_private_channel(channel_id) old_channel = copy.copy(channel) - channel._update_group(**data) + channel._update_group(data) self.dispatch('channel_update', old_channel, channel) return @@ -370,32 +371,32 @@ class ConnectionState: channel = server.get_channel(channel_id) if channel is not None: old_channel = copy.copy(channel) - channel._update(server=server, **data) + channel._update(server, data) self.dispatch('channel_update', old_channel, channel) def parse_channel_create(self, data): ch_type = try_enum(ChannelType, data.get('type')) channel = None if ch_type in (ChannelType.group, ChannelType.private): - channel = PrivateChannel(self.user, **data) + channel = PrivateChannel(me=self.user, data=data, state=self.ctx) self._add_private_channel(channel) else: server = self._get_server(data.get('guild_id')) if server is not None: - channel = Channel(server=server, **data) + channel = Channel(server=server, state=self.ctx, data=data) server._add_channel(channel) self.dispatch('channel_create', channel) def parse_channel_recipient_add(self, data): channel = self._get_private_channel(data.get('channel_id')) - user = User(**data.get('user', {})) + user = self.try_insert_user(data['user']) channel.recipients.append(user) self.dispatch('group_join', channel, user) def parse_channel_recipient_remove(self, data): channel = self._get_private_channel(data.get('channel_id')) - user = User(**data.get('user', {})) + user = self.try_insert_user(data['user']) try: channel.recipients.remove(user) except ValueError: @@ -411,7 +412,7 @@ class ConnectionState: roles.append(role) data['roles'] = sorted(roles, key=lambda r: int(r.id)) - return Member(server=server, **data) + return Member(server=server, data=data, state=self.ctx) def parse_guild_member_add(self, data): server = self._get_server(data.get('guild_id')) @@ -441,35 +442,18 @@ class ConnectionState: def parse_guild_member_update(self, data): server = self._get_server(data.get('guild_id')) - user_id = data['user']['id'] + user = data['user'] + user_id = user['id'] member = server.get_member(user_id) if member is not None: - user = data['user'] old_member = member._copy() - member.name = user['username'] - member.discriminator = user['discriminator'] - member.avatar = user['avatar'] - member.bot = user.get('bot', False) - - # the nickname change is optional, - # if it isn't in the payload then it didn't change - if 'nick' in data: - member.nick = data['nick'] - - # update the roles - member.roles = [server.default_role] - for role in server.roles: - if role.id in data['roles']: - member.roles.append(role) - - # sort the roles by ID since they can be "randomised" - member.roles.sort(key=lambda r: int(r.id)) + member._update(data, user) self.dispatch('member_update', old_member, member) def parse_guild_emojis_update(self, data): server = self._get_server(data.get('guild_id')) before_emojis = server.emojis - server.emojis = [Emoji(server=server, **e) for e in data.get('emojis', [])] + server.emojis = [Emoji(server=server, data=e, state=self.ctx) for e in data.get('emojis', [])] self.dispatch('server_emojis_update', before_emojis, server.emojis) def _get_create_server(self, data): @@ -584,13 +568,13 @@ class ConnectionState: server = self._get_server(data.get('guild_id')) if server is not None: if 'user' in data: - user = User(**data['user']) + user = self.try_insert_user(data['user']) self.dispatch('member_unban', server, user) def parse_guild_role_create(self, data): - server = self._get_server(data.get('guild_id')) - role_data = data.get('role', {}) - role = Role(server=server, **role_data) + server = self._get_server(data['guild_id']) + role_data = data['role'] + role = Role(server=server, data=role_data, state=self.ctx) server._add_role(role) self.dispatch('server_role_create', role) @@ -609,11 +593,12 @@ class ConnectionState: def parse_guild_role_update(self, data): server = self._get_server(data.get('guild_id')) if server is not None: - role_id = data['role']['id'] + role_data = data['role'] + role_id = role_data['id'] role = utils.find(lambda r: r.id == role_id, server.roles) if role is not None: old_role = copy.copy(role) - role._update(**data['role']) + role._update(role_data) self.dispatch('server_role_update', old_role, role) def parse_guild_members_chunk(self, data): diff --git a/discord/user.py b/discord/user.py index 37f0cc518..919d6deba 100644 --- a/discord/user.py +++ b/discord/user.py @@ -58,14 +58,15 @@ class User: Specifies if the user is a bot account. """ - __slots__ = ['name', 'id', 'discriminator', 'avatar', 'bot'] - - def __init__(self, **kwargs): - self.name = kwargs.get('username') - self.id = kwargs.get('id') - self.discriminator = kwargs.get('discriminator') - self.avatar = kwargs.get('avatar') - self.bot = kwargs.get('bot', False) + __slots__ = ['name', 'id', 'discriminator', 'avatar', 'bot', '_state'] + + def __init__(self, *, state, data): + self._state = state + self.name = data['username'] + self.id = data['id'] + self.discriminator = data['discriminator'] + self.avatar = data['avatar'] + self.bot = data.get('bot', False) def __str__(self): return '{0.name}#{0.discriminator}'.format(self)