Browse Source

Split channel types.

This splits them into the following:

* DMChannel
* GroupChannel
* VoiceChannel
* TextChannel

This also makes the channels "stateful".
pull/447/head
Rapptz 9 years ago
parent
commit
53ab263125
  1. 2
      discord/__init__.py
  2. 277
      discord/abc.py
  3. 4
      discord/calls.py
  4. 464
      discord/channel.py
  5. 8
      discord/client.py
  6. 6
      discord/errors.py
  7. 62
      discord/iterators.py
  8. 4
      discord/message.py
  9. 10
      discord/server.py
  10. 40
      discord/state.py

2
discord/__init__.py

@ -21,7 +21,7 @@ from .client import Client, AppInfo, ChannelPermissions
from .user import User
from .game import Game
from .emoji import Emoji
from .channel import Channel, PrivateChannel
from .channel import *
from .server import Server
from .member import Member, VoiceState
from .message import Message

277
discord/abc.py

@ -25,6 +25,12 @@ DEALINGS IN THE SOFTWARE.
"""
import abc
import io
import os
import asyncio
from .message import Message
from .iterators import LogsFromIterator
class Snowflake(metaclass=abc.ABCMeta):
__slots__ = ()
@ -75,3 +81,274 @@ class User(metaclass=abc.ABCMeta):
return NotImplemented
return True
return NotImplemented
class GuildChannel(metaclass=abc.ABCMeta):
__slots__ = ()
@property
@abc.abstractmethod
def mention(self):
raise NotImplementedError
@abc.abstractmethod
def overwrites_for(self, obj):
raise NotImplementedError
@abc.abstractmethod
def permissions_for(self, user):
raise NotImplementedError
@classmethod
def __subclasshook__(cls, C):
if cls is GuildChannel:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
mro = C.__mro__
for attr in ('name', 'server', 'overwrites_for', 'permissions_for', 'mention'):
for base in mro:
if attr in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented
class PrivateChannel(metaclass=abc.ABCMeta):
__slots__ = ()
@classmethod
def __subclasshook__(cls, C):
if cls is PrivateChannel:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
mro = C.__mro__
for base in mro:
if 'me' in base.__dict__:
return True
return NotImplemented
return NotImplemented
class MessageChannel(metaclass=abc.ABCMeta):
__slots__ = ()
@abc.abstractmethod
def _get_destination(self):
raise NotImplementedError
@asyncio.coroutine
def send_message(self, content, *, tts=False):
"""|coro|
Sends a message to the channel with the content given.
The content must be a type that can convert to a string through ``str(content)``.
Parameters
------------
content
The content of the message to send.
tts: bool
Indicates if the message should be sent using text-to-speech.
Raises
--------
HTTPException
Sending the message failed.
Forbidden
You do not have the proper permissions to send the message.
Returns
---------
:class:`Message`
The message that was sent.
"""
channel_id, guild_id = self._get_destination()
content = str(content)
data = yield from self._state.http.send_message(channel_id, content, guild_id=guild_id, tts=tts)
return Message(channel=self, state=self._state, data=data)
@asyncio.coroutine
def send_typing(self):
"""|coro|
Send a *typing* status to the channel.
*Typing* status will go away after 10 seconds, or after a message is sent.
"""
channel_id, _ = self._get_destination()
yield from self._state.http.send_typing(channel_id)
@asyncio.coroutine
def upload(self, fp, *, filename=None, content=None, tts=False):
"""|coro|
Sends a message to the channel with the file given.
The ``fp`` parameter should be either a string denoting the location for a
file or a *file-like object*. The *file-like object* passed is **not closed**
at the end of execution. You are responsible for closing it yourself.
.. note::
If the file-like object passed is opened via ``open`` then the modes
'rb' should be used.
The ``filename`` parameter is the filename of the file.
If this is not given then it defaults to ``fp.name`` or if ``fp`` is a string
then the ``filename`` will default to the string given. You can overwrite
this value by passing this in.
Parameters
------------
fp
The *file-like object* or file path to send.
filename: str
The filename of the file. Defaults to ``fp.name`` if it's available.
content: str
The content of the message to send along with the file. This is
forced into a string by a ``str(content)`` call.
tts: bool
If the content of the message should be sent with TTS enabled.
Raises
-------
HTTPException
Sending the file failed.
Returns
--------
:class:`Message`
The message sent.
"""
channel_id, guild_id = self._get_destination()
try:
with open(fp, 'rb') as f:
buffer = io.BytesIO(f.read())
if filename is None:
_, filename = os.path.split(fp)
except TypeError:
buffer = fp
state = self._state
data = yield from state.http.send_file(channel_id, buffer, guild_id=guild_id,
filename=filename, content=content, tts=tts)
return Message(channel=self, state=state, data=data)
@asyncio.coroutine
def get_message(self, id):
"""|coro|
Retrieves a single :class:`Message` from a channel.
This can only be used by bot accounts.
Parameters
------------
id: int
The message ID to look for.
Returns
--------
:class:`Message`
The message asked for.
Raises
--------
NotFound
The specified message was not found.
Forbidden
You do not have the permissions required to get a message.
HTTPException
Retrieving the message failed.
"""
data = yield from self._state.http.get_message(self.id, id)
return Message(channel=self, state=self._state, data=data)
@asyncio.coroutine
def pins(self):
"""|coro|
Returns a list of :class:`Message` that are currently pinned.
Raises
-------
HTTPException
Retrieving the pinned messages failed.
"""
state = self._state
data = yield from state.http.pins_from(self.id)
return [Message(channel=self, state=state, data=m) for m in data]
def history(self, *, limit=100, before=None, after=None, around=None, reverse=None):
"""Return an async iterator that enables receiving the channel's message history.
You must have Read Message History permissions to use this.
All parameters are optional.
Parameters
-----------
limit: int
The number of messages to retrieve.
before: :class:`Message` or `datetime`
Retrieve messages before this date or message.
If a date is provided it must be a timezone-naive datetime representing UTC time.
after: :class:`Message` or `datetime`
Retrieve messages after this date or message.
If a date is provided it must be a timezone-naive datetime representing UTC time.
around: :class:`Message` or `datetime`
Retrieve messages around this date or message.
If a date is provided it must be a timezone-naive datetime representing UTC time.
When using this argument, the maximum limit is 101. Note that if the limit is an
even number then this will return at most limit + 1 messages.
reverse: bool
If set to true, return messages in oldest->newest order. If unspecified,
this defaults to ``False`` for most cases. However if passing in a
``after`` parameter then this is set to ``True``. This avoids getting messages
out of order in the ``after`` case.
Raises
------
Forbidden
You do not have permissions to get channel message history.
HTTPException
The request to get message history failed.
Yields
-------
:class:`Message`
The message with the message data parsed.
Examples
---------
Usage ::
counter = 0
async for message in channel.history(limit=200):
if message.author == client.user:
counter += 1
Python 3.4 Usage ::
count = 0
iterator = channel.history(limit=200)
while True:
try:
message = yield from iterator.get()
except discord.NoMoreMessages:
break
else:
if message.author == client.user:
counter += 1
"""
return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse)

4
discord/calls.py

@ -57,7 +57,7 @@ class CallMessage:
@property
def channel(self):
""":class:`PrivateChannel`\: The private channel associated with this message."""
""":class:`GroupChannel`\: The private channel associated with this message."""
return self.message.channel
@property
@ -131,7 +131,7 @@ class GroupCall:
@property
def channel(self):
""":class:`PrivateChannel`\: Returns the channel the group call is in."""
""":class:`GroupChannel`\: Returns the channel the group call is in."""
return self.call.channel
def voice_state_for(self, user):

464
discord/channel.py

@ -23,8 +23,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import copy
from . import utils
from . import utils, abc
from .permissions import Permissions, PermissionOverwrite
from .enums import ChannelType, try_enum
from collections import namedtuple
@ -33,82 +32,54 @@ from .role import Role
from .user import User
from .member import Member
import copy
import asyncio
__all__ = ('TextChannel', 'VoiceChannel', 'DMChannel', 'GroupChannel', '_channel_factory')
Overwrites = namedtuple('Overwrites', 'id allow deny type')
class Channel(Hashable):
"""Represents a Discord server channel.
class CommonGuildChannel(Hashable):
__slots__ = ()
Supported Operations:
def __str__(self):
return self.name
+-----------+---------------------------------------+
| Operation | Description |
+===========+=======================================+
| x == y | Checks if two channels are equal. |
+-----------+---------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+---------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+---------------------------------------+
| str(x) | Returns the channel's name. |
+-----------+---------------------------------------+
@asyncio.coroutine
def _move(self, position):
if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.')
Attributes
-----------
name: str
The channel name.
server: :class:`Server`
The server the channel belongs to.
id: int
The channel ID.
topic: Optional[str]
The channel's topic. None if it doesn't exist.
is_private: bool
``True`` if the channel is a private channel (i.e. PM). ``False`` in this case.
position: int
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0. The position varies depending on being a voice channel
or a text channel, so a 0 position voice channel is on top of the voice channel
list.
type: :class:`ChannelType`
The channel type. There is a chance that the type will be ``str`` if
the channel type is not within the ones recognised by the enumerator.
bitrate: int
The channel's preferred audio bitrate in bits per second.
voice_members
A list of :class:`Members` that are currently inside this voice channel.
If :attr:`type` is not :attr:`ChannelType.voice` then this is always an empty array.
user_limit: int
The channel's limit for number of members that can be in a voice channel.
"""
http = self._state.http
url = '{0}/{1.server.id}/channels'.format(http.GUILDS, self)
channels = [c for c in self.server.channels if isinstance(c, type(self))]
__slots__ = ( 'voice_members', 'name', 'id', 'server', 'topic',
'type', 'bitrate', 'user_limit', '_state', 'position',
'_permission_overwrites' )
if position >= len(channels):
raise InvalidArgument('Channel position cannot be greater than {}'.format(len(channels) - 1))
def __init__(self, *, state, server, data):
self._state = state
self.id = int(data['id'])
self._update(server, data)
self.voice_members = []
channels.sort(key=lambda c: c.position)
def __str__(self):
return self.name
try:
# remove ourselves from the channel list
channels.remove(self)
except ValueError:
# not there somehow lol
return
else:
# add ourselves at our designated position
channels.insert(position, self)
def _update(self, server, data):
self.server = server
self.name = data['name']
self.topic = data.get('topic')
self.position = data['position']
self.bitrate = data.get('bitrate')
self.type = data['type']
self.user_limit = data.get('user_limit')
self._permission_overwrites = []
payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)]
yield from http.patch(url, json=payload, bucket='move_channel')
def _fill_overwrites(self, data):
self._overwrites = []
everyone_index = 0
everyone_id = self.server.id
for index, overridden in enumerate(data.get('permission_overwrites', [])):
overridden_id = int(overridden.pop('id'))
self._permission_overwrites.append(Overwrites(id=overridden_id, **overridden))
self._overwrites.append(Overwrites(id=overridden_id, **overridden))
if overridden['type'] == 'member':
continue
@ -122,7 +93,7 @@ class Channel(Hashable):
everyone_index = index
# do the swap
tmp = self._permission_overwrites
tmp = self._overwrites
if tmp:
tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index]
@ -131,7 +102,7 @@ class Channel(Hashable):
"""Returns a list of :class:`Roles` that have been overridden from
their default values in the :attr:`Server.roles` attribute."""
ret = []
for overwrite in filter(lambda o: o.type == 'role', self._permission_overwrites):
for overwrite in filter(lambda o: o.type == 'role', self._overwrites):
role = utils.get(self.server.roles, id=overwrite.id)
if role is None:
continue
@ -146,10 +117,6 @@ class Channel(Hashable):
"""bool : Indicates if this is the default channel for the :class:`Server` it belongs to."""
return self.server.id == self.id
@property
def is_private(self):
return False
@property
def mention(self):
"""str : The string that allows you to mention the channel."""
@ -182,7 +149,7 @@ class Channel(Hashable):
else:
predicate = lambda p: True
for overwrite in filter(predicate, self._permission_overwrites):
for overwrite in filter(predicate, self._overwrites):
if overwrite.id == obj.id:
allow = Permissions(overwrite.allow)
deny = Permissions(overwrite.deny)
@ -276,7 +243,7 @@ class Channel(Hashable):
allows = 0
# Apply channel specific role permission overwrites
for overwrite in self._permission_overwrites:
for overwrite in self._overwrites:
if overwrite.type == 'role' and overwrite.id in member_role_ids:
denies |= overwrite.deny
allows |= overwrite.allow
@ -284,7 +251,7 @@ class Channel(Hashable):
base.handle_overwrite(allow=allows, deny=denies)
# Apply member specific permission overwrites
for overwrite in self._permission_overwrites:
for overwrite in self._overwrites:
if overwrite.type == 'member' and overwrite.id == member.id:
base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny)
break
@ -307,14 +274,286 @@ class Channel(Hashable):
base.value &= ~denied.value
# text channels do not have voice related permissions
if self.type is ChannelType.text:
if isinstance(self, TextChannel):
denied = Permissions.voice()
base.value &= ~denied.value
return base
class PrivateChannel(Hashable):
"""Represents a Discord private channel.
@asyncio.coroutine
def delete(self):
"""|coro|
Deletes the channel.
You must have Manage Channel permission to use this.
Raises
-------
Forbidden
You do not have proper permissions to delete the channel.
NotFound
The channel was not found or was already deleted.
HTTPException
Deleting the channel failed.
"""
yield from self._state.http.delete_channel(self.id)
class TextChannel(abc.MessageChannel, CommonGuildChannel):
"""Represents a Discord server text channel.
Supported Operations:
+-----------+---------------------------------------+
| Operation | Description |
+===========+=======================================+
| x == y | Checks if two channels are equal. |
+-----------+---------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+---------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+---------------------------------------+
| str(x) | Returns the channel's name. |
+-----------+---------------------------------------+
Attributes
-----------
name: str
The channel name.
server: :class:`Server`
The server the channel belongs to.
id: int
The channel ID.
topic: Optional[str]
The channel's topic. None if it doesn't exist.
position: int
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
"""
__slots__ = ( 'name', 'id', 'server', 'topic', '_state',
'position', '_overwrites' )
def __init__(self, *, state, server, data):
self._state = state
self.id = int(data['id'])
self._update(server, data)
def _update(self, server, data):
self.server = server
self.name = data['name']
self.topic = data.get('topic')
self.position = data['position']
self._fill_overwrites(data)
def _get_destination(self):
return self.id, self.server.id
@asyncio.coroutine
def edit(self, **options):
"""|coro|
Edits the channel.
You must have the Manage Channel permission to use this.
Parameters
----------
name: str
The new channel name.
topic: str
The new channel's topic.
position: int
The new channel's position.
Raises
------
InvalidArgument
If position is less than 0 or greater than the number of channels.
Forbidden
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
"""
try:
position = options.pop('position')
except KeyError:
pass
else:
yield from self._move(position)
self.position = position
if options:
data = yield from self._state.http.edit_channel(self.id, **options)
self._update(self.server, data)
class VoiceChannel(CommonGuildChannel):
"""Represents a Discord server voice channel.
Supported Operations:
+-----------+---------------------------------------+
| Operation | Description |
+===========+=======================================+
| x == y | Checks if two channels are equal. |
+-----------+---------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+---------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+---------------------------------------+
| str(x) | Returns the channel's name. |
+-----------+---------------------------------------+
Attributes
-----------
name: str
The channel name.
server: :class:`Server`
The server the channel belongs to.
id: int
The channel ID.
position: int
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
bitrate: int
The channel's preferred audio bitrate in bits per second.
voice_members
A list of :class:`Members` that are currently inside this voice channel.
user_limit: int
The channel's limit for number of members that can be in a voice channel.
"""
__slots__ = ( 'voice_members', 'name', 'id', 'server', 'bitrate',
'user_limit', '_state', 'position', '_overwrites' )
def __init__(self, *, state, server, data):
self._state = state
self.id = int(data['id'])
self._update(server, data)
self.voice_members = []
def _update(self, server, data):
self.server = server
self.name = data['name']
self.position = data['position']
self.bitrate = data.get('bitrate')
self.user_limit = data.get('user_limit')
self._fill_overwrites(data)
@asyncio.coroutine
def edit(self, **options):
"""|coro|
Edits the channel.
You must have the Manage Channel permission to use this.
Parameters
----------
bitrate: int
The new channel's bitrate.
user_limit: int
The new channel's user limit.
position: int
The new channel's position.
Raises
------
Forbidden
You do not have permissions to edit the channel.
HTTPException
Editing the channel failed.
"""
try:
position = options.pop('position')
except KeyError:
pass
else:
yield from self._move(position)
self.position = position
if options:
data = yield from self._state.http.edit_channel(self.id, **options)
self._update(self.server, data)
class DMChannel(abc.MessageChannel, Hashable):
"""Represents a Discord direct message channel.
Supported Operations:
+-----------+-------------------------------------------------+
| Operation | Description |
+===========+=================================================+
| x == y | Checks if two channels are equal. |
+-----------+-------------------------------------------------+
| x != y | Checks if two channels are not equal. |
+-----------+-------------------------------------------------+
| hash(x) | Returns the channel's hash. |
+-----------+-------------------------------------------------+
| str(x) | Returns a string representation of the channel |
+-----------+-------------------------------------------------+
Attributes
----------
recipient: :class:`User`
The user you are participating with in the direct message channel.
me: :class:`User`
The user presenting yourself.
id: int
The direct message channel ID.
"""
__slots__ = ('id', 'recipient', 'me', '_state')
def __init__(self, *, me, state, data):
self._state = state
self.recipient = state.try_insert_user(data['recipients'][0])
self.me = me
self.id = int(data['id'])
def _get_destination(self):
return self.id, None
def __str__(self):
return 'Direct Message with %s' % self.recipient
@property
def created_at(self):
"""Returns the direct message channel's creation time in UTC."""
return utils.snowflake_time(self.id)
def permissions_for(self, user=None):
"""Handles permission resolution for a :class:`User`.
This function is there for compatibility with other channel types.
Actual direct messages do not really have the concept of permissions.
This returns all the Text related permissions set to true except:
- send_tts_messages: You cannot send TTS messages in a DM.
- manage_messages: You cannot delete others messages in a DM.
Parameters
-----------
user: :class:`User`
The user to check permissions for. This parameter is ignored
but kept for compatibility.
Returns
--------
:class:`Permissions`
The resolved permissions.
"""
base = Permissions.text()
base.send_tts_messages = False
base.manage_messages = False
return base
class GroupChannel(abc.MessageChannel, Hashable):
"""Represents a Discord group channel.
Supported Operations:
@ -333,50 +572,42 @@ class PrivateChannel(Hashable):
Attributes
----------
recipients: list of :class:`User`
The users you are participating with in the private channel.
The users you are participating with in the group channel.
me: :class:`User`
The user presenting yourself.
id: int
The private channel ID.
is_private: bool
``True`` if the channel is a private channel (i.e. PM). ``True`` in this case.
type: :class:`ChannelType`
The type of private channel.
owner: Optional[:class:`User`]
The user that owns the private channel. If the channel type is not
:attr:`ChannelType.group` then this is always ``None``.
The group channel ID.
owner: :class:`User`
The user that owns the group channel.
icon: Optional[str]
The private channel's icon hash. If the channel type is not
:attr:`ChannelType.group` then this is always ``None``.
The group channel's icon hash if provided.
name: Optional[str]
The private channel's name. If the channel type is not
:attr:`ChannelType.group` then this is always ``None``.
The group channel's name if provided.
"""
__slots__ = ('id', 'recipients', 'type', 'owner', 'icon', 'name', 'me', '_state')
__slots__ = ('id', 'recipients', 'owner', 'icon', 'name', 'me', '_state')
def __init__(self, *, me, state, data):
self._state = state
self.recipients = [state.try_insert_user(u) for u in data['recipients']]
self.id = int(data['id'])
self.me = me
self.type = try_enum(ChannelType, data['type'])
self._update_group(data)
def _update_group(self, data):
owner_id = utils._get_as_snowflake(data, 'owner_id')
self.icon = data.get('icon')
self.name = data.get('name')
if owner_id == self.me.id:
self.owner = self.me
else:
self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
@property
def is_private(self):
return True
def _get_destination(self):
return self.id, None
def __str__(self):
if self.type is ChannelType.private:
return 'Direct Message with {0.name}'.format(self.user)
if self.name:
return self.name
@ -385,15 +616,6 @@ class PrivateChannel(Hashable):
return ', '.join(map(lambda x: x.name, self.recipients))
@property
def user(self):
"""A property that returns the first recipient of the private channel.
This is mainly for compatibility and ease of use with old style private
channels that had a single recipient.
"""
return self.recipients[0]
@property
def icon_url(self):
"""Returns the channel's icon URL if available or an empty string otherwise."""
@ -404,23 +626,22 @@ class PrivateChannel(Hashable):
@property
def created_at(self):
"""Returns the private channel's creation time in UTC."""
"""Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id)
def permissions_for(self, user):
"""Handles permission resolution for a :class:`User`.
This function is there for compatibility with :class:`Channel`.
This function is there for compatibility with other channel types.
Actual private messages do not really have the concept of permissions.
Actual direct messages do not really have the concept of permissions.
This returns all the Text related permissions set to true except:
- send_tts_messages: You cannot send TTS messages in a PM.
- manage_messages: You cannot delete others messages in a PM.
- send_tts_messages: You cannot send TTS messages in a DM.
- manage_messages: You cannot delete others messages in a DM.
This also handles permissions for :attr:`ChannelType.group` channels
such as kicking or mentioning everyone.
This also checks the kick_members permission if the user is the owner.
Parameters
-----------
@ -436,11 +657,22 @@ class PrivateChannel(Hashable):
base = Permissions.text()
base.send_tts_messages = False
base.manage_messages = False
base.mention_everyone = self.type is ChannelType.group
base.mention_everyone = True
if user == self.owner:
if user.id == self.owner.id:
base.kick_members = True
return base
def _channel_factory(channel_type):
value = try_enum(ChannelType, channel_type)
if value is ChannelType.text:
return TextChannel, value
elif value is ChannelType.voice:
return VoiceChannel, value
elif value is ChannelType.private:
return DMChannel, value
elif value is ChannelType.group:
return GroupChannel, value
else:
return None, value

8
discord/client.py

@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
from . import __version__ as library_version
from .user import User
from .member import Member
from .channel import Channel, PrivateChannel
from .channel import *
from .server import Server
from .message import Message
from .invite import Invite
@ -261,9 +261,9 @@ class Client:
@asyncio.coroutine
def _resolve_destination(self, destination):
if isinstance(destination, Channel):
if isinstance(destination, TextChannel):
return destination.id, destination.server.id
elif isinstance(destination, PrivateChannel):
elif isinstance(destination, DMChannel):
return destination.id, None
elif isinstance(destination, Server):
return destination.id, destination.id
@ -283,7 +283,7 @@ class Client:
# couldn't find it in cache so YOLO
return destination.id, destination.id
else:
fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}'
fmt = 'Destination must be TextChannel, DMChannel, User, or Object. Received {0.__class__.__name__}'
raise InvalidArgument(fmt.format(destination))
def __getattr__(self, name):

6
discord/errors.py

@ -38,6 +38,12 @@ class ClientException(DiscordException):
"""
pass
class NoMoreMessages(DiscordException):
"""Exception that is thrown when a ``history`` operation has no more
messages. This is only exposed for Python 3.4 only.
"""
pass
class GatewayNotFound(DiscordException):
"""An exception that is usually thrown when the gateway hub
for the :class:`Client` websocket is not found."""

62
discord/iterators.py

@ -27,23 +27,26 @@ DEALINGS IN THE SOFTWARE.
import sys
import asyncio
import aiohttp
import datetime
from .errors import NoMoreMessages
from .utils import time_snowflake
from .message import Message
from .object import Object
PY35 = sys.version_info >= (3, 5)
class LogsFromIterator:
"""Iterator for recieving logs.
"""Iterator for receiving logs.
The messages endpoint has two behaviors we care about here:
The messages endpoint has two behaviours we care about here:
If `before` is specified, the messages endpoint returns the `limit`
newest messages before `before`, sorted with newest first. For filling over
100 messages, update the `before` parameter to the oldest message recieved.
100 messages, update the `before` parameter to the oldest message received.
Messages will be returned in order by time.
If `after` is specified, it returns the `limit` oldest messages after
`after`, sorted with newest first. For filling over 100 messages, update the
`after` parameter to the newest message recieved. If messages are not
`after` parameter to the newest message received. If messages are not
reversed, they will be out of order (99-0, 199-100, so on)
A note that if both before and after are specified, before is ignored by the
@ -51,7 +54,6 @@ class LogsFromIterator:
Parameters
-----------
client : class:`Client`
channel: class:`Channel`
Channel from which to request logs
limit : int
@ -66,21 +68,34 @@ class LogsFromIterator:
reverse: bool
If set to true, return messages in oldest->newest order. Recommended
when using with "after" queries with limit over 100, otherwise messages
will be out of order. Defaults to False for backwards compatability.
will be out of order.
"""
def __init__(self, client, channel, limit,
before=None, after=None, around=None, reverse=False):
self.client = client
def __init__(self, channel, limit,
before=None, after=None, around=None, reverse=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))
if isinstance(around, datetime.datetime):
around = Object(id=time_snowflake(around))
self.channel = channel
self.ctx = channel._state
self.logs_from = channel._state.http.logs_from
self.limit = limit
self.before = before
self.after = after
self.around = around
if reverse is None:
self.reverse = after is not None
else:
self.reverse = reverse
self._filter = None # message dict -> bool
self.messages = asyncio.Queue()
self.ctx = client.connection.ctx
if self.around:
if self.limit > 101:
@ -92,29 +107,32 @@ class LogsFromIterator:
self._retrieve_messages = self._retrieve_messages_around_strategy
if self.before and self.after:
self._filter = lambda m: self.after.id < m['id'] < self.before.id
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before:
self._filter = lambda m: m['id'] < self.before.id
self._filter = lambda m: int(m['id']) < self.before.id
elif self.after:
self._filter = lambda m: self.after.id < m['id']
self._filter = lambda m: self.after.id < int(m['id'])
elif self.before and self.after:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy
self._filter = lambda m: m['id'] < self.before.id
self._filter = lambda m: int(m['id']) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy
self._filter = lambda m: m['id'] > self.after.id
self._filter = lambda m: int(m['id']) > self.after.id
elif self.after:
self._retrieve_messages = self._retrieve_messages_after_strategy
else:
self._retrieve_messages = self._retrieve_messages_before_strategy
@asyncio.coroutine
def iterate(self):
def get(self):
if self.messages.empty():
yield from self.fill_messages()
try:
return self.messages.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreMessages()
@asyncio.coroutine
def fill_messages(self):
@ -136,7 +154,7 @@ class LogsFromIterator:
@asyncio.coroutine
def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter."""
data = yield from self.client._logs_from(self.channel, retrieve, before=self.before)
data = yield from self.logs_from(self.channel.id, retrieve, before=getattr(self.before, 'id', None))
if len(data):
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
@ -145,7 +163,7 @@ class LogsFromIterator:
@asyncio.coroutine
def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter."""
data = yield from self.client._logs_from(self.channel, retrieve, after=self.after)
data = yield from self.logs_from(self.channel.id, retrieve, after=getattr(self.after, 'id', None))
if len(data):
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
@ -155,7 +173,7 @@ class LogsFromIterator:
def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter."""
if self.around:
data = yield from self.client._logs_from(self.channel, retrieve, around=self.around)
data = yield from self.logs_from(self.channel.id, retrieve, around=getattr(self.around, 'id', None))
self.around = None
return data
return []
@ -168,9 +186,9 @@ class LogsFromIterator:
@asyncio.coroutine
def __anext__(self):
try:
msg = yield from self.iterate()
msg = yield from self.get()
return msg
except asyncio.QueueEmpty:
except NoMoreMessages:
# if we're still empty at this point...
# we didn't get any new messages so stop looping
raise StopAsyncIteration()

4
discord/message.py

@ -24,9 +24,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from . import utils
from .user import User
from .reaction import Reaction
from . import utils, abc
from .object import Object
from .calls import CallMessage
import re
@ -292,7 +292,7 @@ class Message:
self.channel.is_private = True
return
if not self.channel.is_private:
if isinstance(self.channel, abc.GuildChannel):
self.server = self.channel.server
found = self.server.get_member(self.author.id)
if found is not None:

10
discord/server.py

@ -29,8 +29,8 @@ from .role import Role
from .member import Member, VoiceState
from .emoji import Emoji
from .game import Game
from .channel import Channel
from .enums import ServerRegion, Status, try_enum, VerificationLevel
from .channel import *
from .enums import ServerRegion, Status, ChannelType, try_enum, VerificationLevel
from .mixins import Hashable
import copy
@ -273,7 +273,11 @@ class Server(Hashable):
if 'channels' in data:
channels = data['channels']
for c in channels:
channel = Channel(server=self, data=c, state=self._state)
if c['type'] == ChannelType.text.value:
channel = TextChannel(server=self, data=c, state=self._state)
else:
channel = VoiceChannel(server=self, data=c, state=self._state)
self._add_channel(channel)
@utils.cached_slot_property('_default_role')

40
discord/state.py

@ -30,7 +30,7 @@ from .game import Game
from .emoji import Emoji
from .reaction import Reaction
from .message import Message
from .channel import Channel, PrivateChannel
from .channel import *
from .member import Member
from .role import Role
from . import utils, compat
@ -153,13 +153,13 @@ class ConnectionState:
def _add_private_channel(self, channel):
self._private_channels[channel.id] = channel
if channel.type is ChannelType.private:
self._private_channels_by_user[channel.user.id] = channel
if isinstance(channel, DMChannel):
self._private_channels_by_user[channel.recipient.id] = channel
def _remove_private_channel(self, channel):
self._private_channels.pop(channel.id, None)
if channel.type is ChannelType.private:
self._private_channels_by_user.pop(channel.user.id, None)
if isinstance(channel, DMChannel):
self._private_channels_by_user.pop(channel.recipient.id, None)
def _get_message(self, msg_id):
return utils.find(lambda m: m.id == msg_id, self.messages)
@ -229,7 +229,8 @@ class ConnectionState:
servers.append(server)
for pm in data.get('private_channels'):
self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx))
factory, _ = _channel_factory(pm['type'])
self._add_private_channel(factory(me=self.user, data=pm, state=self.ctx))
compat.create_task(self._delay_ready(), loop=self.loop)
@ -348,13 +349,18 @@ class ConnectionState:
self.user = User(state=self.ctx, data=data)
def parse_channel_delete(self, data):
server = self._get_server(int(data['guild_id']))
server = self._get_server(utils._get_as_snowflake(data, 'guild_id'))
channel_id = int(data['id'])
if server is not None:
channel_id = data.get('id')
channel = server.get_channel(channel_id)
if channel is not None:
server._remove_channel(channel)
self.dispatch('channel_delete', channel)
else:
# the reason we're doing this is so it's also removed from the
# private channel by user cache as well
channel = self._get_private_channel(channel_id)
self._remove_private_channel(channel)
def parse_channel_update(self, data):
channel_type = try_enum(ChannelType, data.get('type'))
@ -375,15 +381,15 @@ class ConnectionState:
self.dispatch('channel_update', old_channel, channel)
def parse_channel_create(self, data):
ch_type = try_enum(ChannelType, data.get('type'))
factory, ch_type = _channel_factory(data['type'])
channel = None
if ch_type in (ChannelType.group, ChannelType.private):
channel = PrivateChannel(me=self.user, data=data, state=self.ctx)
channel = factory(me=self.user, data=data, state=self.ctx)
self._add_private_channel(channel)
else:
server = self._get_server(utils._get_as_snowflake(data, 'guild_id'))
if server is not None:
channel = Channel(server=server, state=self.ctx, data=data)
channel = factory(server=server, state=self.ctx, data=data)
server._add_channel(channel)
self.dispatch('channel_create', channel)
@ -638,14 +644,12 @@ class ConnectionState:
if channel is not None:
member = None
user_id = utils._get_as_snowflake(data, 'user_id')
is_private = getattr(channel, 'is_private', None)
if is_private == None:
return
if is_private:
member = channel.user
else:
if isinstance(channel, DMChannel):
member = channel.recipient
elif isinstance(channel, TextChannel):
member = channel.server.get_member(user_id)
elif isinstance(channel, GroupChannel):
member = utils.find(lambda x: x.id == user_id, channel.recipients)
if member is not None:
timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp'))

Loading…
Cancel
Save