Browse Source

Move away from StateContext and use ConnectionState directly.

pull/447/head
Rapptz 8 years ago
parent
commit
5e6bfecb07
  1. 4
      discord/client.py
  2. 4
      discord/message.py
  3. 2
      discord/reaction.py
  4. 40
      discord/state.py

4
discord/client.py

@ -1201,7 +1201,7 @@ class Client:
data = yield from self.http.application_info() data = yield from self.http.application_info()
return AppInfo(id=data['id'], name=data['name'], return AppInfo(id=data['id'], name=data['name'],
description=data['description'], icon=data['icon'], description=data['description'], icon=data['icon'],
owner=User(state=self.connection.ctx, data=data['owner'])) owner=User(state=self.connection, data=data['owner']))
@asyncio.coroutine @asyncio.coroutine
def get_user_info(self, user_id): def get_user_info(self, user_id):
@ -1230,4 +1230,4 @@ class Client:
Fetching the user failed. Fetching the user failed.
""" """
data = yield from self.http.get_user_info(user_id) data = yield from self.http.get_user_info(user_id)
return User(state=self.connection.ctx, data=data) return User(state=self.connection, data=data)

4
discord/message.py

@ -135,7 +135,7 @@ class Message:
setattr(self, key, transform(value)) setattr(self, key, transform(value))
def _add_reaction(self, data): def _add_reaction(self, data):
emoji = self._state.reaction_emoji(data['emoji']) emoji = self._state.get_reaction_emoji(data['emoji'])
reaction = discord.utils.find(lambda r: r.emoji == emoji, self.reactions) reaction = discord.utils.find(lambda r: r.emoji == emoji, self.reactions)
is_me = data['me'] = int(data['user_id']) == self._state.self_id is_me = data['me'] = int(data['user_id']) == self._state.self_id
@ -150,7 +150,7 @@ class Message:
return reaction return reaction
def _remove_reaction(self, data): def _remove_reaction(self, data):
emoji = self._state.reaction_emoji(data['emoji']) emoji = self._state.get_reaction_emoji(data['emoji'])
reaction = discord.utils.find(lambda r: r.emoji == emoji, self.reactions) reaction = discord.utils.find(lambda r: r.emoji == emoji, self.reactions)
if reaction is None: if reaction is None:

2
discord/reaction.py

@ -63,7 +63,7 @@ class Reaction:
def __init__(self, *, message, data, emoji=None): def __init__(self, *, message, data, emoji=None):
self.message = message self.message = message
self.emoji = message._state.reaction_emoji(data['emoji']) if emoji is None else emoji self.emoji = message._state.get_reaction_emoji(data['emoji']) if emoji is None else emoji
self.count = data.get('count', 1) self.count = data.get('count', 1)
self.me = data.get('me') self.me = data.get('me')

40
discord/state.py

@ -52,26 +52,16 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
class StateContext:
__slots__ = ('store_user', 'http', 'self_id', 'store_emoji', 'reaction_emoji', 'loop')
def __init__(self, **kwargs):
for attr, value in kwargs.items():
setattr(self, attr, value)
class ConnectionState: class ConnectionState:
def __init__(self, *, dispatch, chunker, syncer, http, loop, **options): def __init__(self, *, dispatch, chunker, syncer, http, loop, **options):
self.loop = loop self.loop = loop
self.http = http
self.max_messages = max(options.get('max_messages', 5000), 100) self.max_messages = max(options.get('max_messages', 5000), 100)
self.dispatch = dispatch self.dispatch = dispatch
self.chunker = chunker self.chunker = chunker
self.syncer = syncer self.syncer = syncer
self.is_bot = None self.is_bot = None
self._listeners = [] self._listeners = []
self.ctx = StateContext(store_user=self.store_user,
store_emoji=self.store_emoji,
reaction_emoji=self._get_reaction_emoji,
http=http, self_id=None, loop=loop)
self.clear() self.clear()
def clear(self): def clear(self):
@ -114,6 +104,11 @@ class ConnectionState:
for index in reversed(removed): for index in reversed(removed):
del self._listeners[index] del self._listeners[index]
@property
def self_id(self):
u = self.user
return u.id if u else None
@property @property
def voice_clients(self): def voice_clients(self):
return list(self._voice_clients.values()) return list(self._voice_clients.values())
@ -137,7 +132,7 @@ class ConnectionState:
try: try:
return self._users[user_id] return self._users[user_id]
except KeyError: except KeyError:
self._users[user_id] = user = User(state=self.ctx, data=data) self._users[user_id] = user = User(state=self, data=data)
return user return user
def store_emoji(self, guild, data): def store_emoji(self, guild, data):
@ -145,7 +140,7 @@ class ConnectionState:
try: try:
return self._emojis[emoji_id] return self._emojis[emoji_id]
except KeyError: except KeyError:
self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self.ctx, data=data) self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data)
return emoji return emoji
@property @property
@ -185,7 +180,7 @@ class ConnectionState:
return discord.utils.find(lambda m: m.id == msg_id, self.messages) return discord.utils.find(lambda m: m.id == msg_id, self.messages)
def _add_guild_from_data(self, guild): def _add_guild_from_data(self, guild):
guild = Guild(data=guild, state=self.ctx) guild = Guild(data=guild, state=self)
Guild.me = property(lambda s: s.get_member(self.user.id)) Guild.me = property(lambda s: s.get_member(self.user.id))
Guild.voice_client = property(lambda s: self._get_voice_client(s.id)) Guild.voice_client = property(lambda s: self._get_voice_client(s.id))
self._add_guild(guild) self._add_guild(guild)
@ -240,7 +235,6 @@ class ConnectionState:
def parse_ready(self, data): def parse_ready(self, data):
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
self.user = self.store_user(data['user']) self.user = self.store_user(data['user'])
self.ctx.self_id = self.user.id
guilds = data.get('guilds') guilds = data.get('guilds')
guilds = self._ready_state.guilds guilds = self._ready_state.guilds
@ -251,7 +245,7 @@ class ConnectionState:
for pm in data.get('private_channels'): for pm in data.get('private_channels'):
factory, _ = _channel_factory(pm['type']) factory, _ = _channel_factory(pm['type'])
self._add_private_channel(factory(me=self.user, data=pm, state=self.ctx)) self._add_private_channel(factory(me=self.user, data=pm, state=self))
discord.compat.create_task(self._delay_ready(), loop=self.loop) discord.compat.create_task(self._delay_ready(), loop=self.loop)
@ -260,7 +254,7 @@ class ConnectionState:
def parse_message_create(self, data): def parse_message_create(self, data):
channel = self.get_channel(int(data['channel_id'])) channel = self.get_channel(int(data['channel_id']))
message = Message(channel=channel, data=data, state=self.ctx) message = Message(channel=channel, data=data, state=self)
self.dispatch('message', message) self.dispatch('message', message)
self.messages.append(message) self.messages.append(message)
@ -341,7 +335,7 @@ class ConnectionState:
self.dispatch('member_update', old_member, member) self.dispatch('member_update', old_member, member)
def parse_user_update(self, data): def parse_user_update(self, data):
self.user = User(state=self.ctx, data=data) self.user = User(state=self, data=data)
def parse_channel_delete(self, data): def parse_channel_delete(self, data):
guild = self._get_guild(discord.utils._get_as_snowflake(data, 'guild_id')) guild = self._get_guild(discord.utils._get_as_snowflake(data, 'guild_id'))
@ -379,12 +373,12 @@ class ConnectionState:
factory, ch_type = _channel_factory(data['type']) factory, ch_type = _channel_factory(data['type'])
channel = None channel = None
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
channel = factory(me=self.user, data=data, state=self.ctx) channel = factory(me=self.user, data=data, state=self)
self._add_private_channel(channel) self._add_private_channel(channel)
else: else:
guild = self._get_guild(discord.utils._get_as_snowflake(data, 'guild_id')) guild = self._get_guild(discord.utils._get_as_snowflake(data, 'guild_id'))
if guild is not None: if guild is not None:
channel = factory(guild=guild, state=self.ctx, data=data) channel = factory(guild=guild, state=self, data=data)
guild._add_channel(channel) guild._add_channel(channel)
self.dispatch('channel_create', channel) self.dispatch('channel_create', channel)
@ -413,7 +407,7 @@ class ConnectionState:
roles.append(role) roles.append(role)
data['roles'] = sorted(roles, key=lambda r: r.id) data['roles'] = sorted(roles, key=lambda r: r.id)
return Member(guild=guild, data=data, state=self.ctx) return Member(guild=guild, data=data, state=self)
def parse_guild_member_add(self, data): def parse_guild_member_add(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
@ -577,7 +571,7 @@ class ConnectionState:
def parse_guild_role_create(self, data): def parse_guild_role_create(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
role_data = data['role'] role_data = data['role']
role = Role(guild=guild, data=role_data, state=self.ctx) role = Role(guild=guild, data=role_data, state=self)
guild._add_role(role) guild._add_role(role)
self.dispatch('guild_role_create', role) self.dispatch('guild_role_create', role)
@ -679,7 +673,7 @@ class ConnectionState:
else: else:
return None return None
def _get_reaction_emoji(self, data): def get_reaction_emoji(self, data):
emoji_id = discord.utils._get_as_snowflake(data, 'id') emoji_id = discord.utils._get_as_snowflake(data, 'id')
if not emoji_id: if not emoji_id:

Loading…
Cancel
Save