From 53ab2631252bf0977446d762f07b3821edb151ee Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 17 Oct 2016 01:10:22 -0400 Subject: [PATCH] Split channel types. This splits them into the following: * DMChannel * GroupChannel * VoiceChannel * TextChannel This also makes the channels "stateful". --- discord/__init__.py | 2 +- discord/abc.py | 277 +++++++++++++++++++++++++ discord/calls.py | 4 +- discord/channel.py | 468 ++++++++++++++++++++++++++++++++----------- discord/client.py | 8 +- discord/errors.py | 6 + discord/iterators.py | 70 ++++--- discord/message.py | 4 +- discord/server.py | 10 +- discord/state.py | 40 ++-- 10 files changed, 715 insertions(+), 174 deletions(-) diff --git a/discord/__init__.py b/discord/__init__.py index 1fd3d83c6..55427cca7 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -21,7 +21,7 @@ from .client import Client, AppInfo, ChannelPermissions from .user import User from .game import Game from .emoji import Emoji -from .channel import Channel, PrivateChannel +from .channel import * from .server import Server from .member import Member, VoiceState from .message import Message diff --git a/discord/abc.py b/discord/abc.py index 2bda266e9..0b42b0e87 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -25,6 +25,12 @@ DEALINGS IN THE SOFTWARE. """ import abc +import io +import os +import asyncio + +from .message import Message +from .iterators import LogsFromIterator class Snowflake(metaclass=abc.ABCMeta): __slots__ = () @@ -75,3 +81,274 @@ class User(metaclass=abc.ABCMeta): return NotImplemented return True return NotImplemented + +class GuildChannel(metaclass=abc.ABCMeta): + __slots__ = () + + @property + @abc.abstractmethod + def mention(self): + raise NotImplementedError + + @abc.abstractmethod + def overwrites_for(self, obj): + raise NotImplementedError + + @abc.abstractmethod + def permissions_for(self, user): + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, C): + if cls is GuildChannel: + if Snowflake.__subclasshook__(C) is NotImplemented: + return NotImplemented + + mro = C.__mro__ + for attr in ('name', 'server', 'overwrites_for', 'permissions_for', 'mention'): + for base in mro: + if attr in base.__dict__: + break + else: + return NotImplemented + return True + return NotImplemented + +class PrivateChannel(metaclass=abc.ABCMeta): + __slots__ = () + + @classmethod + def __subclasshook__(cls, C): + if cls is PrivateChannel: + if Snowflake.__subclasshook__(C) is NotImplemented: + return NotImplemented + + mro = C.__mro__ + for base in mro: + if 'me' in base.__dict__: + return True + return NotImplemented + return NotImplemented + +class MessageChannel(metaclass=abc.ABCMeta): + __slots__ = () + + @abc.abstractmethod + def _get_destination(self): + raise NotImplementedError + + @asyncio.coroutine + def send_message(self, content, *, tts=False): + """|coro| + + Sends a message to the channel with the content given. + + The content must be a type that can convert to a string through ``str(content)``. + + Parameters + ------------ + content + The content of the message to send. + tts: bool + Indicates if the message should be sent using text-to-speech. + + Raises + -------- + HTTPException + Sending the message failed. + Forbidden + You do not have the proper permissions to send the message. + + Returns + --------- + :class:`Message` + The message that was sent. + """ + + channel_id, guild_id = self._get_destination() + content = str(content) + data = yield from self._state.http.send_message(channel_id, content, guild_id=guild_id, tts=tts) + return Message(channel=self, state=self._state, data=data) + + @asyncio.coroutine + def send_typing(self): + """|coro| + + Send a *typing* status to the channel. + + *Typing* status will go away after 10 seconds, or after a message is sent. + """ + + channel_id, _ = self._get_destination() + yield from self._state.http.send_typing(channel_id) + + @asyncio.coroutine + def upload(self, fp, *, filename=None, content=None, tts=False): + """|coro| + + Sends a message to the channel with the file given. + + The ``fp`` parameter should be either a string denoting the location for a + file or a *file-like object*. The *file-like object* passed is **not closed** + at the end of execution. You are responsible for closing it yourself. + + .. note:: + + If the file-like object passed is opened via ``open`` then the modes + 'rb' should be used. + + The ``filename`` parameter is the filename of the file. + If this is not given then it defaults to ``fp.name`` or if ``fp`` is a string + then the ``filename`` will default to the string given. You can overwrite + this value by passing this in. + + Parameters + ------------ + fp + The *file-like object* or file path to send. + filename: str + The filename of the file. Defaults to ``fp.name`` if it's available. + content: str + The content of the message to send along with the file. This is + forced into a string by a ``str(content)`` call. + tts: bool + If the content of the message should be sent with TTS enabled. + + Raises + ------- + HTTPException + Sending the file failed. + + Returns + -------- + :class:`Message` + The message sent. + """ + + channel_id, guild_id = self._get_destination() + + try: + with open(fp, 'rb') as f: + buffer = io.BytesIO(f.read()) + if filename is None: + _, filename = os.path.split(fp) + except TypeError: + buffer = fp + + state = self._state + data = yield from state.http.send_file(channel_id, buffer, guild_id=guild_id, + filename=filename, content=content, tts=tts) + + return Message(channel=self, state=state, data=data) + + @asyncio.coroutine + def get_message(self, id): + """|coro| + + Retrieves a single :class:`Message` from a channel. + + This can only be used by bot accounts. + + Parameters + ------------ + id: int + The message ID to look for. + + Returns + -------- + :class:`Message` + The message asked for. + + Raises + -------- + NotFound + The specified message was not found. + Forbidden + You do not have the permissions required to get a message. + HTTPException + Retrieving the message failed. + """ + + data = yield from self._state.http.get_message(self.id, id) + return Message(channel=self, state=self._state, data=data) + + @asyncio.coroutine + def pins(self): + """|coro| + + Returns a list of :class:`Message` that are currently pinned. + + Raises + ------- + HTTPException + Retrieving the pinned messages failed. + """ + + state = self._state + data = yield from state.http.pins_from(self.id) + return [Message(channel=self, state=state, data=m) for m in data] + + def history(self, *, limit=100, before=None, after=None, around=None, reverse=None): + """Return an async iterator that enables receiving the channel's message history. + + You must have Read Message History permissions to use this. + + All parameters are optional. + + Parameters + ----------- + limit: int + The number of messages to retrieve. + before: :class:`Message` or `datetime` + Retrieve messages before this date or message. + If a date is provided it must be a timezone-naive datetime representing UTC time. + after: :class:`Message` or `datetime` + Retrieve messages after this date or message. + If a date is provided it must be a timezone-naive datetime representing UTC time. + around: :class:`Message` or `datetime` + Retrieve messages around this date or message. + If a date is provided it must be a timezone-naive datetime representing UTC time. + When using this argument, the maximum limit is 101. Note that if the limit is an + even number then this will return at most limit + 1 messages. + reverse: bool + If set to true, return messages in oldest->newest order. If unspecified, + this defaults to ``False`` for most cases. However if passing in a + ``after`` parameter then this is set to ``True``. This avoids getting messages + out of order in the ``after`` case. + + Raises + ------ + Forbidden + You do not have permissions to get channel message history. + HTTPException + The request to get message history failed. + + Yields + ------- + :class:`Message` + The message with the message data parsed. + + Examples + --------- + + Usage :: + + counter = 0 + async for message in channel.history(limit=200): + if message.author == client.user: + counter += 1 + + Python 3.4 Usage :: + + count = 0 + iterator = channel.history(limit=200) + while True: + try: + message = yield from iterator.get() + except discord.NoMoreMessages: + break + else: + if message.author == client.user: + counter += 1 + """ + return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse) diff --git a/discord/calls.py b/discord/calls.py index 94c55a146..0925f7136 100644 --- a/discord/calls.py +++ b/discord/calls.py @@ -57,7 +57,7 @@ class CallMessage: @property def channel(self): - """:class:`PrivateChannel`\: The private channel associated with this message.""" + """:class:`GroupChannel`\: The private channel associated with this message.""" return self.message.channel @property @@ -131,7 +131,7 @@ class GroupCall: @property def channel(self): - """:class:`PrivateChannel`\: Returns the channel the group call is in.""" + """:class:`GroupChannel`\: Returns the channel the group call is in.""" return self.call.channel def voice_state_for(self, user): diff --git a/discord/channel.py b/discord/channel.py index f79a2d5d1..b1961dd4d 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -23,8 +23,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import copy -from . import utils +from . import utils, abc from .permissions import Permissions, PermissionOverwrite from .enums import ChannelType, try_enum from collections import namedtuple @@ -33,82 +32,54 @@ from .role import Role from .user import User from .member import Member +import copy +import asyncio + +__all__ = ('TextChannel', 'VoiceChannel', 'DMChannel', 'GroupChannel', '_channel_factory') + Overwrites = namedtuple('Overwrites', 'id allow deny type') -class Channel(Hashable): - """Represents a Discord server channel. +class CommonGuildChannel(Hashable): + __slots__ = () - Supported Operations: + def __str__(self): + return self.name - +-----------+---------------------------------------+ - | Operation | Description | - +===========+=======================================+ - | x == y | Checks if two channels are equal. | - +-----------+---------------------------------------+ - | x != y | Checks if two channels are not equal. | - +-----------+---------------------------------------+ - | hash(x) | Returns the channel's hash. | - +-----------+---------------------------------------+ - | str(x) | Returns the channel's name. | - +-----------+---------------------------------------+ + @asyncio.coroutine + def _move(self, position): + if position < 0: + raise InvalidArgument('Channel position cannot be less than 0.') - Attributes - ----------- - name: str - The channel name. - server: :class:`Server` - The server the channel belongs to. - id: int - The channel ID. - topic: Optional[str] - The channel's topic. None if it doesn't exist. - is_private: bool - ``True`` if the channel is a private channel (i.e. PM). ``False`` in this case. - position: int - The position in the channel list. This is a number that starts at 0. e.g. the - top channel is position 0. The position varies depending on being a voice channel - or a text channel, so a 0 position voice channel is on top of the voice channel - list. - type: :class:`ChannelType` - The channel type. There is a chance that the type will be ``str`` if - the channel type is not within the ones recognised by the enumerator. - bitrate: int - The channel's preferred audio bitrate in bits per second. - voice_members - A list of :class:`Members` that are currently inside this voice channel. - If :attr:`type` is not :attr:`ChannelType.voice` then this is always an empty array. - user_limit: int - The channel's limit for number of members that can be in a voice channel. - """ + http = self._state.http + url = '{0}/{1.server.id}/channels'.format(http.GUILDS, self) + channels = [c for c in self.server.channels if isinstance(c, type(self))] - __slots__ = ( 'voice_members', 'name', 'id', 'server', 'topic', - 'type', 'bitrate', 'user_limit', '_state', 'position', - '_permission_overwrites' ) + if position >= len(channels): + raise InvalidArgument('Channel position cannot be greater than {}'.format(len(channels) - 1)) - def __init__(self, *, state, server, data): - self._state = state - self.id = int(data['id']) - self._update(server, data) - self.voice_members = [] + channels.sort(key=lambda c: c.position) - def __str__(self): - return self.name + try: + # remove ourselves from the channel list + channels.remove(self) + except ValueError: + # not there somehow lol + return + else: + # add ourselves at our designated position + channels.insert(position, self) - def _update(self, server, data): - self.server = server - self.name = data['name'] - self.topic = data.get('topic') - self.position = data['position'] - self.bitrate = data.get('bitrate') - self.type = data['type'] - self.user_limit = data.get('user_limit') - self._permission_overwrites = [] + payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)] + yield from http.patch(url, json=payload, bucket='move_channel') + + def _fill_overwrites(self, data): + self._overwrites = [] everyone_index = 0 everyone_id = self.server.id for index, overridden in enumerate(data.get('permission_overwrites', [])): overridden_id = int(overridden.pop('id')) - self._permission_overwrites.append(Overwrites(id=overridden_id, **overridden)) + self._overwrites.append(Overwrites(id=overridden_id, **overridden)) if overridden['type'] == 'member': continue @@ -122,7 +93,7 @@ class Channel(Hashable): everyone_index = index # do the swap - tmp = self._permission_overwrites + tmp = self._overwrites if tmp: tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] @@ -131,7 +102,7 @@ class Channel(Hashable): """Returns a list of :class:`Roles` that have been overridden from their default values in the :attr:`Server.roles` attribute.""" ret = [] - for overwrite in filter(lambda o: o.type == 'role', self._permission_overwrites): + for overwrite in filter(lambda o: o.type == 'role', self._overwrites): role = utils.get(self.server.roles, id=overwrite.id) if role is None: continue @@ -146,10 +117,6 @@ class Channel(Hashable): """bool : Indicates if this is the default channel for the :class:`Server` it belongs to.""" return self.server.id == self.id - @property - def is_private(self): - return False - @property def mention(self): """str : The string that allows you to mention the channel.""" @@ -182,7 +149,7 @@ class Channel(Hashable): else: predicate = lambda p: True - for overwrite in filter(predicate, self._permission_overwrites): + for overwrite in filter(predicate, self._overwrites): if overwrite.id == obj.id: allow = Permissions(overwrite.allow) deny = Permissions(overwrite.deny) @@ -276,7 +243,7 @@ class Channel(Hashable): allows = 0 # Apply channel specific role permission overwrites - for overwrite in self._permission_overwrites: + for overwrite in self._overwrites: if overwrite.type == 'role' and overwrite.id in member_role_ids: denies |= overwrite.deny allows |= overwrite.allow @@ -284,7 +251,7 @@ class Channel(Hashable): base.handle_overwrite(allow=allows, deny=denies) # Apply member specific permission overwrites - for overwrite in self._permission_overwrites: + for overwrite in self._overwrites: if overwrite.type == 'member' and overwrite.id == member.id: base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) break @@ -307,14 +274,286 @@ class Channel(Hashable): base.value &= ~denied.value # text channels do not have voice related permissions - if self.type is ChannelType.text: + if isinstance(self, TextChannel): denied = Permissions.voice() base.value &= ~denied.value return base -class PrivateChannel(Hashable): - """Represents a Discord private channel. + @asyncio.coroutine + def delete(self): + """|coro| + + Deletes the channel. + + You must have Manage Channel permission to use this. + + Raises + ------- + Forbidden + You do not have proper permissions to delete the channel. + NotFound + The channel was not found or was already deleted. + HTTPException + Deleting the channel failed. + """ + yield from self._state.http.delete_channel(self.id) + +class TextChannel(abc.MessageChannel, CommonGuildChannel): + """Represents a Discord server text channel. + + Supported Operations: + + +-----------+---------------------------------------+ + | Operation | Description | + +===========+=======================================+ + | x == y | Checks if two channels are equal. | + +-----------+---------------------------------------+ + | x != y | Checks if two channels are not equal. | + +-----------+---------------------------------------+ + | hash(x) | Returns the channel's hash. | + +-----------+---------------------------------------+ + | str(x) | Returns the channel's name. | + +-----------+---------------------------------------+ + + Attributes + ----------- + name: str + The channel name. + server: :class:`Server` + The server the channel belongs to. + id: int + The channel ID. + topic: Optional[str] + The channel's topic. None if it doesn't exist. + position: int + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + """ + + __slots__ = ( 'name', 'id', 'server', 'topic', '_state', + 'position', '_overwrites' ) + + def __init__(self, *, state, server, data): + self._state = state + self.id = int(data['id']) + self._update(server, data) + + def _update(self, server, data): + self.server = server + self.name = data['name'] + self.topic = data.get('topic') + self.position = data['position'] + self._fill_overwrites(data) + + def _get_destination(self): + return self.id, self.server.id + + @asyncio.coroutine + def edit(self, **options): + """|coro| + + Edits the channel. + + You must have the Manage Channel permission to use this. + + Parameters + ---------- + name: str + The new channel name. + topic: str + The new channel's topic. + position: int + The new channel's position. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of channels. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + try: + position = options.pop('position') + except KeyError: + pass + else: + yield from self._move(position) + self.position = position + + if options: + data = yield from self._state.http.edit_channel(self.id, **options) + self._update(self.server, data) + +class VoiceChannel(CommonGuildChannel): + """Represents a Discord server voice channel. + + Supported Operations: + + +-----------+---------------------------------------+ + | Operation | Description | + +===========+=======================================+ + | x == y | Checks if two channels are equal. | + +-----------+---------------------------------------+ + | x != y | Checks if two channels are not equal. | + +-----------+---------------------------------------+ + | hash(x) | Returns the channel's hash. | + +-----------+---------------------------------------+ + | str(x) | Returns the channel's name. | + +-----------+---------------------------------------+ + + Attributes + ----------- + name: str + The channel name. + server: :class:`Server` + The server the channel belongs to. + id: int + The channel ID. + position: int + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + bitrate: int + The channel's preferred audio bitrate in bits per second. + voice_members + A list of :class:`Members` that are currently inside this voice channel. + user_limit: int + The channel's limit for number of members that can be in a voice channel. + """ + + __slots__ = ( 'voice_members', 'name', 'id', 'server', 'bitrate', + 'user_limit', '_state', 'position', '_overwrites' ) + + def __init__(self, *, state, server, data): + self._state = state + self.id = int(data['id']) + self._update(server, data) + self.voice_members = [] + + def _update(self, server, data): + self.server = server + self.name = data['name'] + self.position = data['position'] + self.bitrate = data.get('bitrate') + self.user_limit = data.get('user_limit') + self._fill_overwrites(data) + + @asyncio.coroutine + def edit(self, **options): + """|coro| + + Edits the channel. + + You must have the Manage Channel permission to use this. + + Parameters + ---------- + bitrate: int + The new channel's bitrate. + user_limit: int + The new channel's user limit. + position: int + The new channel's position. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + + try: + position = options.pop('position') + except KeyError: + pass + else: + yield from self._move(position) + self.position = position + + if options: + data = yield from self._state.http.edit_channel(self.id, **options) + self._update(self.server, data) + +class DMChannel(abc.MessageChannel, Hashable): + """Represents a Discord direct message channel. + + Supported Operations: + + +-----------+-------------------------------------------------+ + | Operation | Description | + +===========+=================================================+ + | x == y | Checks if two channels are equal. | + +-----------+-------------------------------------------------+ + | x != y | Checks if two channels are not equal. | + +-----------+-------------------------------------------------+ + | hash(x) | Returns the channel's hash. | + +-----------+-------------------------------------------------+ + | str(x) | Returns a string representation of the channel | + +-----------+-------------------------------------------------+ + + Attributes + ---------- + recipient: :class:`User` + The user you are participating with in the direct message channel. + me: :class:`User` + The user presenting yourself. + id: int + The direct message channel ID. + """ + + __slots__ = ('id', 'recipient', 'me', '_state') + + def __init__(self, *, me, state, data): + self._state = state + self.recipient = state.try_insert_user(data['recipients'][0]) + self.me = me + self.id = int(data['id']) + + def _get_destination(self): + return self.id, None + + def __str__(self): + return 'Direct Message with %s' % self.recipient + + @property + def created_at(self): + """Returns the direct message channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def permissions_for(self, user=None): + """Handles permission resolution for a :class:`User`. + + This function is there for compatibility with other channel types. + + Actual direct messages do not really have the concept of permissions. + + This returns all the Text related permissions set to true except: + + - send_tts_messages: You cannot send TTS messages in a DM. + - manage_messages: You cannot delete others messages in a DM. + + Parameters + ----------- + user: :class:`User` + The user to check permissions for. This parameter is ignored + but kept for compatibility. + + Returns + -------- + :class:`Permissions` + The resolved permissions. + """ + + base = Permissions.text() + base.send_tts_messages = False + base.manage_messages = False + return base + +class GroupChannel(abc.MessageChannel, Hashable): + """Represents a Discord group channel. Supported Operations: @@ -333,50 +572,42 @@ class PrivateChannel(Hashable): Attributes ---------- recipients: list of :class:`User` - The users you are participating with in the private channel. + The users you are participating with in the group channel. me: :class:`User` The user presenting yourself. id: int - The private channel ID. - is_private: bool - ``True`` if the channel is a private channel (i.e. PM). ``True`` in this case. - type: :class:`ChannelType` - The type of private channel. - owner: Optional[:class:`User`] - The user that owns the private channel. If the channel type is not - :attr:`ChannelType.group` then this is always ``None``. + The group channel ID. + owner: :class:`User` + The user that owns the group channel. icon: Optional[str] - The private channel's icon hash. If the channel type is not - :attr:`ChannelType.group` then this is always ``None``. + The group channel's icon hash if provided. name: Optional[str] - The private channel's name. If the channel type is not - :attr:`ChannelType.group` then this is always ``None``. + The group channel's name if provided. """ - __slots__ = ('id', 'recipients', 'type', 'owner', 'icon', 'name', 'me', '_state') + __slots__ = ('id', 'recipients', 'owner', 'icon', 'name', 'me', '_state') def __init__(self, *, me, state, data): self._state = state self.recipients = [state.try_insert_user(u) for u in data['recipients']] self.id = int(data['id']) self.me = me - self.type = try_enum(ChannelType, data['type']) self._update_group(data) def _update_group(self, data): owner_id = utils._get_as_snowflake(data, 'owner_id') self.icon = data.get('icon') self.name = data.get('name') - self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) - @property - def is_private(self): - return True + if owner_id == self.me.id: + self.owner = self.me + else: + self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) - def __str__(self): - if self.type is ChannelType.private: - return 'Direct Message with {0.name}'.format(self.user) + def _get_destination(self): + return self.id, None + def __str__(self): if self.name: return self.name @@ -385,15 +616,6 @@ class PrivateChannel(Hashable): return ', '.join(map(lambda x: x.name, self.recipients)) - @property - def user(self): - """A property that returns the first recipient of the private channel. - - This is mainly for compatibility and ease of use with old style private - channels that had a single recipient. - """ - return self.recipients[0] - @property def icon_url(self): """Returns the channel's icon URL if available or an empty string otherwise.""" @@ -404,27 +626,26 @@ class PrivateChannel(Hashable): @property def created_at(self): - """Returns the private channel's creation time in UTC.""" + """Returns the channel's creation time in UTC.""" return utils.snowflake_time(self.id) def permissions_for(self, user): """Handles permission resolution for a :class:`User`. - This function is there for compatibility with :class:`Channel`. + This function is there for compatibility with other channel types. - Actual private messages do not really have the concept of permissions. + Actual direct messages do not really have the concept of permissions. This returns all the Text related permissions set to true except: - - send_tts_messages: You cannot send TTS messages in a PM. - - manage_messages: You cannot delete others messages in a PM. + - send_tts_messages: You cannot send TTS messages in a DM. + - manage_messages: You cannot delete others messages in a DM. - This also handles permissions for :attr:`ChannelType.group` channels - such as kicking or mentioning everyone. + This also checks the kick_members permission if the user is the owner. Parameters ----------- - user : :class:`User` + user: :class:`User` The user to check permissions for. Returns @@ -436,11 +657,22 @@ class PrivateChannel(Hashable): base = Permissions.text() base.send_tts_messages = False base.manage_messages = False - base.mention_everyone = self.type is ChannelType.group + base.mention_everyone = True - if user == self.owner: + if user.id == self.owner.id: base.kick_members = True return base - +def _channel_factory(channel_type): + value = try_enum(ChannelType, channel_type) + if value is ChannelType.text: + return TextChannel, value + elif value is ChannelType.voice: + return VoiceChannel, value + elif value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return None, value diff --git a/discord/client.py b/discord/client.py index b1dd1c222..94aaa6c4f 100644 --- a/discord/client.py +++ b/discord/client.py @@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE. from . import __version__ as library_version from .user import User from .member import Member -from .channel import Channel, PrivateChannel +from .channel import * from .server import Server from .message import Message from .invite import Invite @@ -261,9 +261,9 @@ class Client: @asyncio.coroutine def _resolve_destination(self, destination): - if isinstance(destination, Channel): + if isinstance(destination, TextChannel): return destination.id, destination.server.id - elif isinstance(destination, PrivateChannel): + elif isinstance(destination, DMChannel): return destination.id, None elif isinstance(destination, Server): return destination.id, destination.id @@ -283,7 +283,7 @@ class Client: # couldn't find it in cache so YOLO return destination.id, destination.id else: - fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}' + fmt = 'Destination must be TextChannel, DMChannel, User, or Object. Received {0.__class__.__name__}' raise InvalidArgument(fmt.format(destination)) def __getattr__(self, name): diff --git a/discord/errors.py b/discord/errors.py index 46d5e940c..5449b77ef 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -38,6 +38,12 @@ class ClientException(DiscordException): """ pass +class NoMoreMessages(DiscordException): + """Exception that is thrown when a ``history`` operation has no more + messages. This is only exposed for Python 3.4 only. + """ + pass + class GatewayNotFound(DiscordException): """An exception that is usually thrown when the gateway hub for the :class:`Client` websocket is not found.""" diff --git a/discord/iterators.py b/discord/iterators.py index 63a8776d3..91470d801 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -27,23 +27,26 @@ DEALINGS IN THE SOFTWARE. import sys import asyncio import aiohttp +import datetime + +from .errors import NoMoreMessages +from .utils import time_snowflake from .message import Message from .object import Object PY35 = sys.version_info >= (3, 5) - class LogsFromIterator: - """Iterator for recieving logs. + """Iterator for receiving logs. - The messages endpoint has two behaviors we care about here: + The messages endpoint has two behaviours we care about here: If `before` is specified, the messages endpoint returns the `limit` newest messages before `before`, sorted with newest first. For filling over - 100 messages, update the `before` parameter to the oldest message recieved. + 100 messages, update the `before` parameter to the oldest message received. Messages will be returned in order by time. If `after` is specified, it returns the `limit` oldest messages after `after`, sorted with newest first. For filling over 100 messages, update the - `after` parameter to the newest message recieved. If messages are not + `after` parameter to the newest message received. If messages are not reversed, they will be out of order (99-0, 199-100, so on) A note that if both before and after are specified, before is ignored by the @@ -51,8 +54,7 @@ class LogsFromIterator: Parameters ----------- - client : class:`Client` - channel : class:`Channel` + channel: class:`Channel` Channel from which to request logs limit : int Maximum number of messages to retrieve @@ -63,24 +65,37 @@ class LogsFromIterator: around : :class:`Message` or id-like Message around which all messages must be. Limit max 101. Note that if limit is an even number, this will return at most limit+1 messages. - reverse : bool + reverse: bool If set to true, return messages in oldest->newest order. Recommended when using with "after" queries with limit over 100, otherwise messages - will be out of order. Defaults to False for backwards compatability. + will be out of order. """ - def __init__(self, client, channel, limit, - before=None, after=None, around=None, reverse=False): - self.client = client + def __init__(self, channel, limit, + before=None, after=None, around=None, reverse=None): + + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + if isinstance(around, datetime.datetime): + around = Object(id=time_snowflake(around)) + self.channel = channel + self.ctx = channel._state + self.logs_from = channel._state.http.logs_from self.limit = limit self.before = before self.after = after self.around = around - self.reverse = reverse + + if reverse is None: + self.reverse = after is not None + else: + self.reverse = reverse + self._filter = None # message dict -> bool self.messages = asyncio.Queue() - self.ctx = client.connection.ctx if self.around: if self.limit > 101: @@ -92,29 +107,32 @@ class LogsFromIterator: self._retrieve_messages = self._retrieve_messages_around_strategy if self.before and self.after: - self._filter = lambda m: self.after.id < m['id'] < self.before.id + self._filter = lambda m: self.after.id < int(m['id']) < self.before.id elif self.before: - self._filter = lambda m: m['id'] < self.before.id + self._filter = lambda m: int(m['id']) < self.before.id elif self.after: - self._filter = lambda m: self.after.id < m['id'] + self._filter = lambda m: self.after.id < int(m['id']) elif self.before and self.after: if self.reverse: self._retrieve_messages = self._retrieve_messages_after_strategy - self._filter = lambda m: m['id'] < self.before.id + self._filter = lambda m: int(m['id']) < self.before.id else: self._retrieve_messages = self._retrieve_messages_before_strategy - self._filter = lambda m: m['id'] > self.after.id + self._filter = lambda m: int(m['id']) > self.after.id elif self.after: self._retrieve_messages = self._retrieve_messages_after_strategy else: self._retrieve_messages = self._retrieve_messages_before_strategy @asyncio.coroutine - def iterate(self): + def get(self): if self.messages.empty(): yield from self.fill_messages() - return self.messages.get_nowait() + try: + return self.messages.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreMessages() @asyncio.coroutine def fill_messages(self): @@ -136,7 +154,7 @@ class LogsFromIterator: @asyncio.coroutine def _retrieve_messages_before_strategy(self, retrieve): """Retrieve messages using before parameter.""" - data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) + data = yield from self.logs_from(self.channel.id, retrieve, before=getattr(self.before, 'id', None)) if len(data): self.limit -= retrieve self.before = Object(id=int(data[-1]['id'])) @@ -145,7 +163,7 @@ class LogsFromIterator: @asyncio.coroutine def _retrieve_messages_after_strategy(self, retrieve): """Retrieve messages using after parameter.""" - data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) + data = yield from self.logs_from(self.channel.id, retrieve, after=getattr(self.after, 'id', None)) if len(data): self.limit -= retrieve self.after = Object(id=int(data[0]['id'])) @@ -155,7 +173,7 @@ class LogsFromIterator: def _retrieve_messages_around_strategy(self, retrieve): """Retrieve messages using around parameter.""" if self.around: - data = yield from self.client._logs_from(self.channel, retrieve, around=self.around) + data = yield from self.logs_from(self.channel.id, retrieve, around=getattr(self.around, 'id', None)) self.around = None return data return [] @@ -168,9 +186,9 @@ class LogsFromIterator: @asyncio.coroutine def __anext__(self): try: - msg = yield from self.iterate() + msg = yield from self.get() return msg - except asyncio.QueueEmpty: + except NoMoreMessages: # if we're still empty at this point... # we didn't get any new messages so stop looping raise StopAsyncIteration() diff --git a/discord/message.py b/discord/message.py index 28ab18d1b..c2caaf9d9 100644 --- a/discord/message.py +++ b/discord/message.py @@ -24,9 +24,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from . import utils from .user import User from .reaction import Reaction +from . import utils, abc from .object import Object from .calls import CallMessage import re @@ -292,7 +292,7 @@ class Message: self.channel.is_private = True return - if not self.channel.is_private: + if isinstance(self.channel, abc.GuildChannel): self.server = self.channel.server found = self.server.get_member(self.author.id) if found is not None: diff --git a/discord/server.py b/discord/server.py index 7d4cb4683..d1523d6fb 100644 --- a/discord/server.py +++ b/discord/server.py @@ -29,8 +29,8 @@ from .role import Role from .member import Member, VoiceState from .emoji import Emoji from .game import Game -from .channel import Channel -from .enums import ServerRegion, Status, try_enum, VerificationLevel +from .channel import * +from .enums import ServerRegion, Status, ChannelType, try_enum, VerificationLevel from .mixins import Hashable import copy @@ -273,7 +273,11 @@ class Server(Hashable): if 'channels' in data: channels = data['channels'] for c in channels: - channel = Channel(server=self, data=c, state=self._state) + if c['type'] == ChannelType.text.value: + channel = TextChannel(server=self, data=c, state=self._state) + else: + channel = VoiceChannel(server=self, data=c, state=self._state) + self._add_channel(channel) @utils.cached_slot_property('_default_role') diff --git a/discord/state.py b/discord/state.py index ad9bb172c..b33d11aee 100644 --- a/discord/state.py +++ b/discord/state.py @@ -30,7 +30,7 @@ from .game import Game from .emoji import Emoji from .reaction import Reaction from .message import Message -from .channel import Channel, PrivateChannel +from .channel import * from .member import Member from .role import Role from . import utils, compat @@ -153,13 +153,13 @@ class ConnectionState: def _add_private_channel(self, channel): self._private_channels[channel.id] = channel - if channel.type is ChannelType.private: - self._private_channels_by_user[channel.user.id] = channel + if isinstance(channel, DMChannel): + self._private_channels_by_user[channel.recipient.id] = channel def _remove_private_channel(self, channel): self._private_channels.pop(channel.id, None) - if channel.type is ChannelType.private: - self._private_channels_by_user.pop(channel.user.id, None) + if isinstance(channel, DMChannel): + self._private_channels_by_user.pop(channel.recipient.id, None) def _get_message(self, msg_id): return utils.find(lambda m: m.id == msg_id, self.messages) @@ -229,7 +229,8 @@ class ConnectionState: servers.append(server) for pm in data.get('private_channels'): - self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx)) + factory, _ = _channel_factory(pm['type']) + self._add_private_channel(factory(me=self.user, data=pm, state=self.ctx)) compat.create_task(self._delay_ready(), loop=self.loop) @@ -348,13 +349,18 @@ class ConnectionState: self.user = User(state=self.ctx, data=data) def parse_channel_delete(self, data): - server = self._get_server(int(data['guild_id'])) + server = self._get_server(utils._get_as_snowflake(data, 'guild_id')) + channel_id = int(data['id']) if server is not None: - channel_id = data.get('id') channel = server.get_channel(channel_id) if channel is not None: server._remove_channel(channel) self.dispatch('channel_delete', channel) + else: + # the reason we're doing this is so it's also removed from the + # private channel by user cache as well + channel = self._get_private_channel(channel_id) + self._remove_private_channel(channel) def parse_channel_update(self, data): channel_type = try_enum(ChannelType, data.get('type')) @@ -375,15 +381,15 @@ class ConnectionState: self.dispatch('channel_update', old_channel, channel) def parse_channel_create(self, data): - ch_type = try_enum(ChannelType, data.get('type')) + factory, ch_type = _channel_factory(data['type']) channel = None if ch_type in (ChannelType.group, ChannelType.private): - channel = PrivateChannel(me=self.user, data=data, state=self.ctx) + channel = factory(me=self.user, data=data, state=self.ctx) self._add_private_channel(channel) else: server = self._get_server(utils._get_as_snowflake(data, 'guild_id')) if server is not None: - channel = Channel(server=server, state=self.ctx, data=data) + channel = factory(server=server, state=self.ctx, data=data) server._add_channel(channel) self.dispatch('channel_create', channel) @@ -638,14 +644,12 @@ class ConnectionState: if channel is not None: member = None user_id = utils._get_as_snowflake(data, 'user_id') - is_private = getattr(channel, 'is_private', None) - if is_private == None: - return - - if is_private: - member = channel.user - else: + if isinstance(channel, DMChannel): + member = channel.recipient + elif isinstance(channel, TextChannel): member = channel.server.get_member(user_id) + elif isinstance(channel, GroupChannel): + member = utils.find(lambda x: x.id == user_id, channel.recipients) if member is not None: timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp'))