Browse Source

Split channel types.

This splits them into the following:

* DMChannel
* GroupChannel
* VoiceChannel
* TextChannel

This also makes the channels "stateful".
pull/447/head
Rapptz 9 years ago
parent
commit
53ab263125
  1. 2
      discord/__init__.py
  2. 277
      discord/abc.py
  3. 4
      discord/calls.py
  4. 468
      discord/channel.py
  5. 8
      discord/client.py
  6. 6
      discord/errors.py
  7. 70
      discord/iterators.py
  8. 4
      discord/message.py
  9. 10
      discord/server.py
  10. 40
      discord/state.py

2
discord/__init__.py

@ -21,7 +21,7 @@ from .client import Client, AppInfo, ChannelPermissions
from .user import User from .user import User
from .game import Game from .game import Game
from .emoji import Emoji from .emoji import Emoji
from .channel import Channel, PrivateChannel from .channel import *
from .server import Server from .server import Server
from .member import Member, VoiceState from .member import Member, VoiceState
from .message import Message from .message import Message

277
discord/abc.py

@ -25,6 +25,12 @@ DEALINGS IN THE SOFTWARE.
""" """
import abc import abc
import io
import os
import asyncio
from .message import Message
from .iterators import LogsFromIterator
class Snowflake(metaclass=abc.ABCMeta): class Snowflake(metaclass=abc.ABCMeta):
__slots__ = () __slots__ = ()
@ -75,3 +81,274 @@ class User(metaclass=abc.ABCMeta):
return NotImplemented return NotImplemented
return True return True
return NotImplemented return NotImplemented
class GuildChannel(metaclass=abc.ABCMeta):
__slots__ = ()
@property
@abc.abstractmethod
def mention(self):
raise NotImplementedError
@abc.abstractmethod
def overwrites_for(self, obj):
raise NotImplementedError
@abc.abstractmethod
def permissions_for(self, user):
raise NotImplementedError
@classmethod
def __subclasshook__(cls, C):
if cls is GuildChannel:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
mro = C.__mro__
for attr in ('name', 'server', 'overwrites_for', 'permissions_for', 'mention'):
for base in mro:
if attr in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented
class PrivateChannel(metaclass=abc.ABCMeta):
__slots__ = ()
@classmethod
def __subclasshook__(cls, C):
if cls is PrivateChannel:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
mro = C.__mro__
for base in mro:
if 'me' in base.__dict__:
return True
return NotImplemented
return NotImplemented
class MessageChannel(metaclass=abc.ABCMeta):
__slots__ = ()
@abc.abstractmethod
def _get_destination(self):
raise NotImplementedError
@asyncio.coroutine
def send_message(self, content, *, tts=False):
"""|coro|
Sends a message to the channel with the content given.
The content must be a type that can convert to a string through ``str(content)``.
Parameters
------------
content
The content of the message to send.
tts: bool
Indicates if the message should be sent using text-to-speech.
Raises
--------
HTTPException
Sending the message failed.
Forbidden
You do not have the proper permissions to send the message.
Returns
---------
:class:`Message`
The message that was sent.
"""
channel_id, guild_id = self._get_destination()
content = str(content)
data = yield from self._state.http.send_message(channel_id, content, guild_id=guild_id, tts=tts)
return Message(channel=self, state=self._state, data=data)
@asyncio.coroutine
def send_typing(self):
"""|coro|
Send a *typing* status to the channel.
*Typing* status will go away after 10 seconds, or after a message is sent.
"""
channel_id, _ = self._get_destination()
yield from self._state.http.send_typing(channel_id)
@asyncio.coroutine
def upload(self, fp, *, filename=None, content=None, tts=False):
"""|coro|
Sends a message to the channel with the file given.
The ``fp`` parameter should be either a string denoting the location for a
file or a *file-like object*. The *file-like object* passed is **not closed**
at the end of execution. You are responsible for closing it yourself.
.. note::
If the file-like object passed is opened via ``open`` then the modes
'rb' should be used.
The ``filename`` parameter is the filename of the file.
If this is not given then it defaults to ``fp.name`` or if ``fp`` is a string
then the ``filename`` will default to the string given. You can overwrite
this value by passing this in.
Parameters
------------
fp
The *file-like object* or file path to send.
filename: str
The filename of the file. Defaults to ``fp.name`` if it's available.
content: str
The content of the message to send along with the file. This is
forced into a string by a ``str(content)`` call.
tts: bool
If the content of the message should be sent with TTS enabled.
Raises
-------
HTTPException
Sending the file failed.
Returns
--------
:class:`Message`
The message sent.
"""
channel_id, guild_id = self._get_destination()
try:
with open(fp, 'rb') as f:
buffer = io.BytesIO(f.read())
if filename is None:
_, filename = os.path.split(fp)
except TypeError:
buffer = fp
state = self._state
data = yield from state.http.send_file(channel_id, buffer, guild_id=guild_id,
filename=filename, content=content, tts=tts)
return Message(channel=self, state=state, data=data)
@asyncio.coroutine
def get_message(self, id):
"""|coro|
Retrieves a single :class:`Message` from a channel.
This can only be used by bot accounts.
Parameters
------------
id: int
The message ID to look for.
Returns
--------
:class:`Message`
The message asked for.
Raises
--------
NotFound
The specified message was not found.
Forbidden
You do not have the permissions required to get a message.
HTTPException
Retrieving the message failed.
"""
data = yield from self._state.http.get_message(self.id, id)
return Message(channel=self, state=self._state, data=data)
@asyncio.coroutine
def pins(self):
"""|coro|
Returns a list of :class:`Message` that are currently pinned.
Raises
-------
HTTPException
Retrieving the pinned messages failed.
"""
state = self._state
data = yield from state.http.pins_from(self.id)
return [Message(channel=self, state=state, data=m) for m in data]
def history(self, *, limit=100, before=None, after=None, around=None, reverse=None):
"""Return an async iterator that enables receiving the channel's message history.
You must have Read Message History permissions to use this.
All parameters are optional.
Parameters
-----------
limit: int
The number of messages to retrieve.
before: :class:`Message` or `datetime`
Retrieve messages before this date or message.
If a date is provided it must be a timezone-naive datetime representing UTC time.
after: :class:`Message` or `datetime`
Retrieve messages after this date or message.
If a date is provided it must be a timezone-naive datetime representing UTC time.
around: :class:`Message` or `datetime`
Retrieve messages around this date or message.
If a date is provided it must be a timezone-naive datetime representing UTC time.
When using this argument, the maximum limit is 101. Note that if the limit is an
even number then this will return at most limit + 1 messages.
reverse: bool
If set to true, return messages in oldest->newest order. If unspecified,
this defaults to ``False`` for most cases. However if passing in a
``after`` parameter then this is set to ``True``. This avoids getting messages
out of order in the ``after`` case.
Raises
------
Forbidden
You do not have permissions to get channel message history.
HTTPException
The request to get message history failed.
Yields
-------
:class:`Message`
The message with the message data parsed.
Examples
---------
Usage ::
counter = 0
async for message in channel.history(limit=200):
if message.author == client.user:
counter += 1
Python 3.4 Usage ::
count = 0
iterator = channel.history(limit=200)
while True:
try:
message = yield from iterator.get()
except discord.NoMoreMessages:
break
else:
if message.author == client.user:
counter += 1
"""
return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse)

4
discord/calls.py

@ -57,7 +57,7 @@ class CallMessage:
@property @property
def channel(self): def channel(self):
""":class:`PrivateChannel`\: The private channel associated with this message.""" """:class:`GroupChannel`\: The private channel associated with this message."""
return self.message.channel return self.message.channel
@property @property
@ -131,7 +131,7 @@ class GroupCall:
@property @property
def channel(self): def channel(self):
""":class:`PrivateChannel`\: Returns the channel the group call is in.""" """:class:`GroupChannel`\: Returns the channel the group call is in."""
return self.call.channel return self.call.channel
def voice_state_for(self, user): def voice_state_for(self, user):

468
discord/channel.py

@ -23,8 +23,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
import copy from . import utils, abc
from . import utils
from .permissions import Permissions, PermissionOverwrite from .permissions import Permissions, PermissionOverwrite
from .enums import ChannelType, try_enum from .enums import ChannelType, try_enum
from collections import namedtuple from collections import namedtuple
@ -33,82 +32,54 @@ from .role import Role
from .user import User from .user import User
from .member import Member from .member import Member
import copy
import asyncio
__all__ = ('TextChannel', 'VoiceChannel', 'DMChannel', 'GroupChannel', '_channel_factory')
Overwrites = namedtuple('Overwrites', 'id allow deny type') Overwrites = namedtuple('Overwrites', 'id allow deny type')
class Channel(Hashable): class CommonGuildChannel(Hashable):
"""Represents a Discord server channel. __slots__ = ()
Supported Operations: def __str__(self):
return self.name
+-----------+---------------------------------------+ @asyncio.coroutine
| Operation | Description | def _move(self, position):
+===========+=======================================+ if position < 0:
| x == y | Checks if two channels are equal. | raise InvalidArgument('Channel position cannot be less than 0.')
+-----------+---------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+---------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+---------------------------------------+
| str(x) | Returns the channel's name. |
+-----------+---------------------------------------+
Attributes http = self._state.http
----------- url = '{0}/{1.server.id}/channels'.format(http.GUILDS, self)
name: str channels = [c for c in self.server.channels if isinstance(c, type(self))]
The channel name.
server: :class:`Server`
The server the channel belongs to.
id: int
The channel ID.
topic: Optional[str]
The channel's topic. None if it doesn't exist.
is_private: bool
``True`` if the channel is a private channel (i.e. PM). ``False`` in this case.
position: int
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0. The position varies depending on being a voice channel
or a text channel, so a 0 position voice channel is on top of the voice channel
list.
type: :class:`ChannelType`
The channel type. There is a chance that the type will be ``str`` if
the channel type is not within the ones recognised by the enumerator.
bitrate: int
The channel's preferred audio bitrate in bits per second.
voice_members
A list of :class:`Members` that are currently inside this voice channel.
If :attr:`type` is not :attr:`ChannelType.voice` then this is always an empty array.
user_limit: int
The channel's limit for number of members that can be in a voice channel.
"""
__slots__ = ( 'voice_members', 'name', 'id', 'server', 'topic', if position >= len(channels):
'type', 'bitrate', 'user_limit', '_state', 'position', raise InvalidArgument('Channel position cannot be greater than {}'.format(len(channels) - 1))
'_permission_overwrites' )
def __init__(self, *, state, server, data): channels.sort(key=lambda c: c.position)
self._state = state
self.id = int(data['id'])
self._update(server, data)
self.voice_members = []
def __str__(self): try:
return self.name # remove ourselves from the channel list
channels.remove(self)
except ValueError:
# not there somehow lol
return
else:
# add ourselves at our designated position
channels.insert(position, self)
def _update(self, server, data): payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)]
self.server = server yield from http.patch(url, json=payload, bucket='move_channel')
self.name = data['name']
self.topic = data.get('topic') def _fill_overwrites(self, data):
self.position = data['position'] self._overwrites = []
self.bitrate = data.get('bitrate')
self.type = data['type']
self.user_limit = data.get('user_limit')
self._permission_overwrites = []
everyone_index = 0 everyone_index = 0
everyone_id = self.server.id everyone_id = self.server.id
for index, overridden in enumerate(data.get('permission_overwrites', [])): for index, overridden in enumerate(data.get('permission_overwrites', [])):
overridden_id = int(overridden.pop('id')) overridden_id = int(overridden.pop('id'))
self._permission_overwrites.append(Overwrites(id=overridden_id, **overridden)) self._overwrites.append(Overwrites(id=overridden_id, **overridden))
if overridden['type'] == 'member': if overridden['type'] == 'member':
continue continue
@ -122,7 +93,7 @@ class Channel(Hashable):
everyone_index = index everyone_index = index
# do the swap # do the swap
tmp = self._permission_overwrites tmp = self._overwrites
if tmp: if tmp:
tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index]
@ -131,7 +102,7 @@ class Channel(Hashable):
"""Returns a list of :class:`Roles` that have been overridden from """Returns a list of :class:`Roles` that have been overridden from
their default values in the :attr:`Server.roles` attribute.""" their default values in the :attr:`Server.roles` attribute."""
ret = [] ret = []
for overwrite in filter(lambda o: o.type == 'role', self._permission_overwrites): for overwrite in filter(lambda o: o.type == 'role', self._overwrites):
role = utils.get(self.server.roles, id=overwrite.id) role = utils.get(self.server.roles, id=overwrite.id)
if role is None: if role is None:
continue continue
@ -146,10 +117,6 @@ class Channel(Hashable):
"""bool : Indicates if this is the default channel for the :class:`Server` it belongs to.""" """bool : Indicates if this is the default channel for the :class:`Server` it belongs to."""
return self.server.id == self.id return self.server.id == self.id
@property
def is_private(self):
return False
@property @property
def mention(self): def mention(self):
"""str : The string that allows you to mention the channel.""" """str : The string that allows you to mention the channel."""
@ -182,7 +149,7 @@ class Channel(Hashable):
else: else:
predicate = lambda p: True predicate = lambda p: True
for overwrite in filter(predicate, self._permission_overwrites): for overwrite in filter(predicate, self._overwrites):
if overwrite.id == obj.id: if overwrite.id == obj.id:
allow = Permissions(overwrite.allow) allow = Permissions(overwrite.allow)
deny = Permissions(overwrite.deny) deny = Permissions(overwrite.deny)
@ -276,7 +243,7 @@ class Channel(Hashable):
allows = 0 allows = 0
# Apply channel specific role permission overwrites # Apply channel specific role permission overwrites
for overwrite in self._permission_overwrites: for overwrite in self._overwrites:
if overwrite.type == 'role' and overwrite.id in member_role_ids: if overwrite.type == 'role' and overwrite.id in member_role_ids:
denies |= overwrite.deny denies |= overwrite.deny
allows |= overwrite.allow allows |= overwrite.allow
@ -284,7 +251,7 @@ class Channel(Hashable):
base.handle_overwrite(allow=allows, deny=denies) base.handle_overwrite(allow=allows, deny=denies)
# Apply member specific permission overwrites # Apply member specific permission overwrites
for overwrite in self._permission_overwrites: for overwrite in self._overwrites:
if overwrite.type == 'member' and overwrite.id == member.id: if overwrite.type == 'member' and overwrite.id == member.id:
base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny)
break break
@ -307,14 +274,286 @@ class Channel(Hashable):
base.value &= ~denied.value base.value &= ~denied.value
# text channels do not have voice related permissions # text channels do not have voice related permissions
if self.type is ChannelType.text: if isinstance(self, TextChannel):
denied = Permissions.voice() denied = Permissions.voice()
base.value &= ~denied.value base.value &= ~denied.value
return base return base
class PrivateChannel(Hashable): @asyncio.coroutine
"""Represents a Discord private channel. def delete(self):
"""|coro|
Deletes the channel.
You must have Manage Channel permission to use this.
Raises
-------
Forbidden
You do not have proper permissions to delete the channel.
NotFound
The channel was not found or was already deleted.
HTTPException
Deleting the channel failed.
"""
yield from self._state.http.delete_channel(self.id)
class TextChannel(abc.MessageChannel, CommonGuildChannel):
"""Represents a Discord server text channel.
Supported Operations:
+-----------+---------------------------------------+
| Operation | Description |
+===========+=======================================+
| x == y | Checks if two channels are equal. |
+-----------+---------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+---------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+---------------------------------------+
| str(x) | Returns the channel's name. |
+-----------+---------------------------------------+
Attributes
-----------
name: str
The channel name.
server: :class:`Server`
The server the channel belongs to.
id: int
The channel ID.
topic: Optional[str]
The channel's topic. None if it doesn't exist.
position: int
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
"""
__slots__ = ( 'name', 'id', 'server', 'topic', '_state',
'position', '_overwrites' )
def __init__(self, *, state, server, data):
self._state = state
self.id = int(data['id'])
self._update(server, data)
def _update(self, server, data):
self.server = server
self.name = data['name']
self.topic = data.get('topic')
self.position = data['position']
self._fill_overwrites(data)
def _get_destination(self):
return self.id, self.server.id
@asyncio.coroutine
def edit(self, **options):
"""|coro|
Edits the channel.
You must have the Manage Channel permission to use this.
Parameters
----------
name: str
The new channel name.
topic: str
The new channel's topic.
position: int
The new channel's position.
Raises
------
InvalidArgument
If position is less than 0 or greater than the number of channels.
Forbidden
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
"""
try:
position = options.pop('position')
except KeyError:
pass
else:
yield from self._move(position)
self.position = position
if options:
data = yield from self._state.http.edit_channel(self.id, **options)
self._update(self.server, data)
class VoiceChannel(CommonGuildChannel):
"""Represents a Discord server voice channel.
Supported Operations:
+-----------+---------------------------------------+
| Operation | Description |
+===========+=======================================+
| x == y | Checks if two channels are equal. |
+-----------+---------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+---------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+---------------------------------------+
| str(x) | Returns the channel's name. |
+-----------+---------------------------------------+
Attributes
-----------
name: str
The channel name.
server: :class:`Server`
The server the channel belongs to.
id: int
The channel ID.
position: int
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
bitrate: int
The channel's preferred audio bitrate in bits per second.
voice_members
A list of :class:`Members` that are currently inside this voice channel.
user_limit: int
The channel's limit for number of members that can be in a voice channel.
"""
__slots__ = ( 'voice_members', 'name', 'id', 'server', 'bitrate',
'user_limit', '_state', 'position', '_overwrites' )
def __init__(self, *, state, server, data):
self._state = state
self.id = int(data['id'])
self._update(server, data)
self.voice_members = []
def _update(self, server, data):
self.server = server
self.name = data['name']
self.position = data['position']
self.bitrate = data.get('bitrate')
self.user_limit = data.get('user_limit')
self._fill_overwrites(data)
@asyncio.coroutine
def edit(self, **options):
"""|coro|
Edits the channel.
You must have the Manage Channel permission to use this.
Parameters
----------
bitrate: int
The new channel's bitrate.
user_limit: int
The new channel's user limit.
position: int
The new channel's position.
Raises
------
Forbidden
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
"""
try:
position = options.pop('position')
except KeyError:
pass
else:
yield from self._move(position)
self.position = position
if options:
data = yield from self._state.http.edit_channel(self.id, **options)
self._update(self.server, data)
class DMChannel(abc.MessageChannel, Hashable):
"""Represents a Discord direct message channel.
Supported Operations:
+-----------+-------------------------------------------------+
| Operation | Description |
+===========+=================================================+
| x == y | Checks if two channels are equal. |
+-----------+-------------------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+-------------------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+-------------------------------------------------+
| str(x) | Returns a string representation of the channel |
+-----------+-------------------------------------------------+
Attributes
----------
recipient: :class:`User`
The user you are participating with in the direct message channel.
me: :class:`User`
The user presenting yourself.
id: int
The direct message channel ID.
"""
__slots__ = ('id', 'recipient', 'me', '_state')
def __init__(self, *, me, state, data):
self._state = state
self.recipient = state.try_insert_user(data['recipients'][0])
self.me = me
self.id = int(data['id'])
def _get_destination(self):
return self.id, None
def __str__(self):
return 'Direct Message with %s' % self.recipient
@property
def created_at(self):
"""Returns the direct message channel's creation time in UTC."""
return utils.snowflake_time(self.id)
def permissions_for(self, user=None):
"""Handles permission resolution for a :class:`User`.
This function is there for compatibility with other channel types.
Actual direct messages do not really have the concept of permissions.
This returns all the Text related permissions set to true except:
- send_tts_messages: You cannot send TTS messages in a DM.
- manage_messages: You cannot delete others messages in a DM.
Parameters
-----------
user: :class:`User`
The user to check permissions for. This parameter is ignored
but kept for compatibility.
Returns
--------
:class:`Permissions`
The resolved permissions.
"""
base = Permissions.text()
base.send_tts_messages = False
base.manage_messages = False
return base
class GroupChannel(abc.MessageChannel, Hashable):
"""Represents a Discord group channel.
Supported Operations: Supported Operations:
@ -333,50 +572,42 @@ class PrivateChannel(Hashable):
Attributes Attributes
---------- ----------
recipients: list of :class:`User` recipients: list of :class:`User`
The users you are participating with in the private channel. The users you are participating with in the group channel.
me: :class:`User` me: :class:`User`
The user presenting yourself. The user presenting yourself.
id: int id: int
The private channel ID. The group channel ID.
is_private: bool owner: :class:`User`
``True`` if the channel is a private channel (i.e. PM). ``True`` in this case. The user that owns the group channel.
type: :class:`ChannelType`
The type of private channel.
owner: Optional[:class:`User`]
The user that owns the private channel. If the channel type is not
:attr:`ChannelType.group` then this is always ``None``.
icon: Optional[str] icon: Optional[str]
The private channel's icon hash. If the channel type is not The group channel's icon hash if provided.
:attr:`ChannelType.group` then this is always ``None``.
name: Optional[str] name: Optional[str]
The private channel's name. If the channel type is not The group channel's name if provided.
:attr:`ChannelType.group` then this is always ``None``.
""" """
__slots__ = ('id', 'recipients', 'type', 'owner', 'icon', 'name', 'me', '_state') __slots__ = ('id', 'recipients', 'owner', 'icon', 'name', 'me', '_state')
def __init__(self, *, me, state, data): def __init__(self, *, me, state, data):
self._state = state self._state = state
self.recipients = [state.try_insert_user(u) for u in data['recipients']] self.recipients = [state.try_insert_user(u) for u in data['recipients']]
self.id = int(data['id']) self.id = int(data['id'])
self.me = me self.me = me
self.type = try_enum(ChannelType, data['type'])
self._update_group(data) self._update_group(data)
def _update_group(self, data): def _update_group(self, data):
owner_id = utils._get_as_snowflake(data, 'owner_id') owner_id = utils._get_as_snowflake(data, 'owner_id')
self.icon = data.get('icon') self.icon = data.get('icon')
self.name = data.get('name') self.name = data.get('name')
self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
@property if owner_id == self.me.id:
def is_private(self): self.owner = self.me
return True else:
self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
def __str__(self): def _get_destination(self):
if self.type is ChannelType.private: return self.id, None
return 'Direct Message with {0.name}'.format(self.user)
def __str__(self):
if self.name: if self.name:
return self.name return self.name
@ -385,15 +616,6 @@ class PrivateChannel(Hashable):
return ', '.join(map(lambda x: x.name, self.recipients)) return ', '.join(map(lambda x: x.name, self.recipients))
@property
def user(self):
"""A property that returns the first recipient of the private channel.
This is mainly for compatibility and ease of use with old style private
channels that had a single recipient.
"""
return self.recipients[0]
@property @property
def icon_url(self): def icon_url(self):
"""Returns the channel's icon URL if available or an empty string otherwise.""" """Returns the channel's icon URL if available or an empty string otherwise."""
@ -404,27 +626,26 @@ class PrivateChannel(Hashable):
@property @property
def created_at(self): def created_at(self):
"""Returns the private channel's creation time in UTC.""" """Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
def permissions_for(self, user): def permissions_for(self, user):
"""Handles permission resolution for a :class:`User`. """Handles permission resolution for a :class:`User`.
This function is there for compatibility with :class:`Channel`. This function is there for compatibility with other channel types.
Actual private messages do not really have the concept of permissions. Actual direct messages do not really have the concept of permissions.
This returns all the Text related permissions set to true except: This returns all the Text related permissions set to true except:
- send_tts_messages: You cannot send TTS messages in a PM. - send_tts_messages: You cannot send TTS messages in a DM.
- manage_messages: You cannot delete others messages in a PM. - manage_messages: You cannot delete others messages in a DM.
This also handles permissions for :attr:`ChannelType.group` channels This also checks the kick_members permission if the user is the owner.
such as kicking or mentioning everyone.
Parameters Parameters
----------- -----------
user : :class:`User` user: :class:`User`
The user to check permissions for. The user to check permissions for.
Returns Returns
@ -436,11 +657,22 @@ class PrivateChannel(Hashable):
base = Permissions.text() base = Permissions.text()
base.send_tts_messages = False base.send_tts_messages = False
base.manage_messages = False base.manage_messages = False
base.mention_everyone = self.type is ChannelType.group base.mention_everyone = True
if user == self.owner: if user.id == self.owner.id:
base.kick_members = True base.kick_members = True
return base return base
def _channel_factory(channel_type):
value = try_enum(ChannelType, channel_type)
if value is ChannelType.text:
return TextChannel, value
elif value is ChannelType.voice:
return VoiceChannel, value
elif value is ChannelType.private:
return DMChannel, value
elif value is ChannelType.group:
return GroupChannel, value
else:
return None, value

8
discord/client.py

@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
from . import __version__ as library_version from . import __version__ as library_version
from .user import User from .user import User
from .member import Member from .member import Member
from .channel import Channel, PrivateChannel from .channel import *
from .server import Server from .server import Server
from .message import Message from .message import Message
from .invite import Invite from .invite import Invite
@ -261,9 +261,9 @@ class Client:
@asyncio.coroutine @asyncio.coroutine
def _resolve_destination(self, destination): def _resolve_destination(self, destination):
if isinstance(destination, Channel): if isinstance(destination, TextChannel):
return destination.id, destination.server.id return destination.id, destination.server.id
elif isinstance(destination, PrivateChannel): elif isinstance(destination, DMChannel):
return destination.id, None return destination.id, None
elif isinstance(destination, Server): elif isinstance(destination, Server):
return destination.id, destination.id return destination.id, destination.id
@ -283,7 +283,7 @@ class Client:
# couldn't find it in cache so YOLO # couldn't find it in cache so YOLO
return destination.id, destination.id return destination.id, destination.id
else: else:
fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}' fmt = 'Destination must be TextChannel, DMChannel, User, or Object. Received {0.__class__.__name__}'
raise InvalidArgument(fmt.format(destination)) raise InvalidArgument(fmt.format(destination))
def __getattr__(self, name): def __getattr__(self, name):

6
discord/errors.py

@ -38,6 +38,12 @@ class ClientException(DiscordException):
""" """
pass pass
class NoMoreMessages(DiscordException):
"""Exception that is thrown when a ``history`` operation has no more
messages. This is only exposed for Python 3.4 only.
"""
pass
class GatewayNotFound(DiscordException): class GatewayNotFound(DiscordException):
"""An exception that is usually thrown when the gateway hub """An exception that is usually thrown when the gateway hub
for the :class:`Client` websocket is not found.""" for the :class:`Client` websocket is not found."""

70
discord/iterators.py

@ -27,23 +27,26 @@ DEALINGS IN THE SOFTWARE.
import sys import sys
import asyncio import asyncio
import aiohttp import aiohttp
import datetime
from .errors import NoMoreMessages
from .utils import time_snowflake
from .message import Message from .message import Message
from .object import Object from .object import Object
PY35 = sys.version_info >= (3, 5) PY35 = sys.version_info >= (3, 5)
class LogsFromIterator: class LogsFromIterator:
"""Iterator for recieving logs. """Iterator for receiving logs.
The messages endpoint has two behaviors we care about here: The messages endpoint has two behaviours we care about here:
If `before` is specified, the messages endpoint returns the `limit` If `before` is specified, the messages endpoint returns the `limit`
newest messages before `before`, sorted with newest first. For filling over newest messages before `before`, sorted with newest first. For filling over
100 messages, update the `before` parameter to the oldest message recieved. 100 messages, update the `before` parameter to the oldest message received.
Messages will be returned in order by time. Messages will be returned in order by time.
If `after` is specified, it returns the `limit` oldest messages after If `after` is specified, it returns the `limit` oldest messages after
`after`, sorted with newest first. For filling over 100 messages, update the `after`, sorted with newest first. For filling over 100 messages, update the
`after` parameter to the newest message recieved. If messages are not `after` parameter to the newest message received. If messages are not
reversed, they will be out of order (99-0, 199-100, so on) reversed, they will be out of order (99-0, 199-100, so on)
A note that if both before and after are specified, before is ignored by the A note that if both before and after are specified, before is ignored by the
@ -51,8 +54,7 @@ class LogsFromIterator:
Parameters Parameters
----------- -----------
client : class:`Client` channel: class:`Channel`
channel : class:`Channel`
Channel from which to request logs Channel from which to request logs
limit : int limit : int
Maximum number of messages to retrieve Maximum number of messages to retrieve
@ -63,24 +65,37 @@ class LogsFromIterator:
around : :class:`Message` or id-like around : :class:`Message` or id-like
Message around which all messages must be. Limit max 101. Note that if Message around which all messages must be. Limit max 101. Note that if
limit is an even number, this will return at most limit+1 messages. limit is an even number, this will return at most limit+1 messages.
reverse : bool reverse: bool
If set to true, return messages in oldest->newest order. Recommended If set to true, return messages in oldest->newest order. Recommended
when using with "after" queries with limit over 100, otherwise messages when using with "after" queries with limit over 100, otherwise messages
will be out of order. Defaults to False for backwards compatability. will be out of order.
""" """
def __init__(self, client, channel, limit, def __init__(self, channel, limit,
before=None, after=None, around=None, reverse=False): before=None, after=None, around=None, reverse=None):
self.client = client
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
if isinstance(around, datetime.datetime):
around = Object(id=time_snowflake(around))
self.channel = channel self.channel = channel
self.ctx = channel._state
self.logs_from = channel._state.http.logs_from
self.limit = limit self.limit = limit
self.before = before self.before = before
self.after = after self.after = after
self.around = around self.around = around
self.reverse = reverse
if reverse is None:
self.reverse = after is not None
else:
self.reverse = reverse
self._filter = None # message dict -> bool self._filter = None # message dict -> bool
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.ctx = client.connection.ctx
if self.around: if self.around:
if self.limit > 101: if self.limit > 101:
@ -92,29 +107,32 @@ class LogsFromIterator:
self._retrieve_messages = self._retrieve_messages_around_strategy self._retrieve_messages = self._retrieve_messages_around_strategy
if self.before and self.after: if self.before and self.after:
self._filter = lambda m: self.after.id < m['id'] < self.before.id self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before: elif self.before:
self._filter = lambda m: m['id'] < self.before.id self._filter = lambda m: int(m['id']) < self.before.id
elif self.after: elif self.after:
self._filter = lambda m: self.after.id < m['id'] self._filter = lambda m: self.after.id < int(m['id'])
elif self.before and self.after: elif self.before and self.after:
if self.reverse: if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy self._retrieve_messages = self._retrieve_messages_after_strategy
self._filter = lambda m: m['id'] < self.before.id self._filter = lambda m: int(m['id']) < self.before.id
else: else:
self._retrieve_messages = self._retrieve_messages_before_strategy self._retrieve_messages = self._retrieve_messages_before_strategy
self._filter = lambda m: m['id'] > self.after.id self._filter = lambda m: int(m['id']) > self.after.id
elif self.after: elif self.after:
self._retrieve_messages = self._retrieve_messages_after_strategy self._retrieve_messages = self._retrieve_messages_after_strategy
else: else:
self._retrieve_messages = self._retrieve_messages_before_strategy self._retrieve_messages = self._retrieve_messages_before_strategy
@asyncio.coroutine @asyncio.coroutine
def iterate(self): def get(self):
if self.messages.empty(): if self.messages.empty():
yield from self.fill_messages() yield from self.fill_messages()
return self.messages.get_nowait() try:
return self.messages.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreMessages()
@asyncio.coroutine @asyncio.coroutine
def fill_messages(self): def fill_messages(self):
@ -136,7 +154,7 @@ class LogsFromIterator:
@asyncio.coroutine @asyncio.coroutine
def _retrieve_messages_before_strategy(self, retrieve): def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter.""" """Retrieve messages using before parameter."""
data = yield from self.client._logs_from(self.channel, retrieve, before=self.before) data = yield from self.logs_from(self.channel.id, retrieve, before=getattr(self.before, 'id', None))
if len(data): if len(data):
self.limit -= retrieve self.limit -= retrieve
self.before = Object(id=int(data[-1]['id'])) self.before = Object(id=int(data[-1]['id']))
@ -145,7 +163,7 @@ class LogsFromIterator:
@asyncio.coroutine @asyncio.coroutine
def _retrieve_messages_after_strategy(self, retrieve): def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter.""" """Retrieve messages using after parameter."""
data = yield from self.client._logs_from(self.channel, retrieve, after=self.after) data = yield from self.logs_from(self.channel.id, retrieve, after=getattr(self.after, 'id', None))
if len(data): if len(data):
self.limit -= retrieve self.limit -= retrieve
self.after = Object(id=int(data[0]['id'])) self.after = Object(id=int(data[0]['id']))
@ -155,7 +173,7 @@ class LogsFromIterator:
def _retrieve_messages_around_strategy(self, retrieve): def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter.""" """Retrieve messages using around parameter."""
if self.around: if self.around:
data = yield from self.client._logs_from(self.channel, retrieve, around=self.around) data = yield from self.logs_from(self.channel.id, retrieve, around=getattr(self.around, 'id', None))
self.around = None self.around = None
return data return data
return [] return []
@ -168,9 +186,9 @@ class LogsFromIterator:
@asyncio.coroutine @asyncio.coroutine
def __anext__(self): def __anext__(self):
try: try:
msg = yield from self.iterate() msg = yield from self.get()
return msg return msg
except asyncio.QueueEmpty: except NoMoreMessages:
# if we're still empty at this point... # if we're still empty at this point...
# we didn't get any new messages so stop looping # we didn't get any new messages so stop looping
raise StopAsyncIteration() raise StopAsyncIteration()

4
discord/message.py

@ -24,9 +24,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from . import utils
from .user import User from .user import User
from .reaction import Reaction from .reaction import Reaction
from . import utils, abc
from .object import Object from .object import Object
from .calls import CallMessage from .calls import CallMessage
import re import re
@ -292,7 +292,7 @@ class Message:
self.channel.is_private = True self.channel.is_private = True
return return
if not self.channel.is_private: if isinstance(self.channel, abc.GuildChannel):
self.server = self.channel.server self.server = self.channel.server
found = self.server.get_member(self.author.id) found = self.server.get_member(self.author.id)
if found is not None: if found is not None:

10
discord/server.py

@ -29,8 +29,8 @@ from .role import Role
from .member import Member, VoiceState from .member import Member, VoiceState
from .emoji import Emoji from .emoji import Emoji
from .game import Game from .game import Game
from .channel import Channel from .channel import *
from .enums import ServerRegion, Status, try_enum, VerificationLevel from .enums import ServerRegion, Status, ChannelType, try_enum, VerificationLevel
from .mixins import Hashable from .mixins import Hashable
import copy import copy
@ -273,7 +273,11 @@ class Server(Hashable):
if 'channels' in data: if 'channels' in data:
channels = data['channels'] channels = data['channels']
for c in channels: for c in channels:
channel = Channel(server=self, data=c, state=self._state) if c['type'] == ChannelType.text.value:
channel = TextChannel(server=self, data=c, state=self._state)
else:
channel = VoiceChannel(server=self, data=c, state=self._state)
self._add_channel(channel) self._add_channel(channel)
@utils.cached_slot_property('_default_role') @utils.cached_slot_property('_default_role')

40
discord/state.py

@ -30,7 +30,7 @@ from .game import Game
from .emoji import Emoji from .emoji import Emoji
from .reaction import Reaction from .reaction import Reaction
from .message import Message from .message import Message
from .channel import Channel, PrivateChannel from .channel import *
from .member import Member from .member import Member
from .role import Role from .role import Role
from . import utils, compat from . import utils, compat
@ -153,13 +153,13 @@ class ConnectionState:
def _add_private_channel(self, channel): def _add_private_channel(self, channel):
self._private_channels[channel.id] = channel self._private_channels[channel.id] = channel
if channel.type is ChannelType.private: if isinstance(channel, DMChannel):
self._private_channels_by_user[channel.user.id] = channel self._private_channels_by_user[channel.recipient.id] = channel
def _remove_private_channel(self, channel): def _remove_private_channel(self, channel):
self._private_channels.pop(channel.id, None) self._private_channels.pop(channel.id, None)
if channel.type is ChannelType.private: if isinstance(channel, DMChannel):
self._private_channels_by_user.pop(channel.user.id, None) self._private_channels_by_user.pop(channel.recipient.id, None)
def _get_message(self, msg_id): def _get_message(self, msg_id):
return utils.find(lambda m: m.id == msg_id, self.messages) return utils.find(lambda m: m.id == msg_id, self.messages)
@ -229,7 +229,8 @@ class ConnectionState:
servers.append(server) servers.append(server)
for pm in data.get('private_channels'): for pm in data.get('private_channels'):
self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx)) factory, _ = _channel_factory(pm['type'])
self._add_private_channel(factory(me=self.user, data=pm, state=self.ctx))
compat.create_task(self._delay_ready(), loop=self.loop) compat.create_task(self._delay_ready(), loop=self.loop)
@ -348,13 +349,18 @@ class ConnectionState:
self.user = User(state=self.ctx, data=data) self.user = User(state=self.ctx, data=data)
def parse_channel_delete(self, data): def parse_channel_delete(self, data):
server = self._get_server(int(data['guild_id'])) server = self._get_server(utils._get_as_snowflake(data, 'guild_id'))
channel_id = int(data['id'])
if server is not None: if server is not None:
channel_id = data.get('id')
channel = server.get_channel(channel_id) channel = server.get_channel(channel_id)
if channel is not None: if channel is not None:
server._remove_channel(channel) server._remove_channel(channel)
self.dispatch('channel_delete', channel) self.dispatch('channel_delete', channel)
else:
# the reason we're doing this is so it's also removed from the
# private channel by user cache as well
channel = self._get_private_channel(channel_id)
self._remove_private_channel(channel)
def parse_channel_update(self, data): def parse_channel_update(self, data):
channel_type = try_enum(ChannelType, data.get('type')) channel_type = try_enum(ChannelType, data.get('type'))
@ -375,15 +381,15 @@ class ConnectionState:
self.dispatch('channel_update', old_channel, channel) self.dispatch('channel_update', old_channel, channel)
def parse_channel_create(self, data): def parse_channel_create(self, data):
ch_type = try_enum(ChannelType, data.get('type')) factory, ch_type = _channel_factory(data['type'])
channel = None channel = None
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
channel = PrivateChannel(me=self.user, data=data, state=self.ctx) channel = factory(me=self.user, data=data, state=self.ctx)
self._add_private_channel(channel) self._add_private_channel(channel)
else: else:
server = self._get_server(utils._get_as_snowflake(data, 'guild_id')) server = self._get_server(utils._get_as_snowflake(data, 'guild_id'))
if server is not None: if server is not None:
channel = Channel(server=server, state=self.ctx, data=data) channel = factory(server=server, state=self.ctx, data=data)
server._add_channel(channel) server._add_channel(channel)
self.dispatch('channel_create', channel) self.dispatch('channel_create', channel)
@ -638,14 +644,12 @@ class ConnectionState:
if channel is not None: if channel is not None:
member = None member = None
user_id = utils._get_as_snowflake(data, 'user_id') user_id = utils._get_as_snowflake(data, 'user_id')
is_private = getattr(channel, 'is_private', None) if isinstance(channel, DMChannel):
if is_private == None: member = channel.recipient
return elif isinstance(channel, TextChannel):
if is_private:
member = channel.user
else:
member = channel.server.get_member(user_id) member = channel.server.get_member(user_id)
elif isinstance(channel, GroupChannel):
member = utils.find(lambda x: x.id == user_id, channel.recipients)
if member is not None: if member is not None:
timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp')) timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp'))

Loading…
Cancel
Save