diff --git a/discord/channel.py b/discord/channel.py index 67a66045d..8968b42da 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -108,7 +108,7 @@ class Channel(Hashable): self._permission_overwrites = [] everyone_index = 0 - everyone_id = self.server.default_role.id + everyone_id = self.server.id for index, overridden in enumerate(kwargs.get('permission_overwrites', [])): overridden_id = overridden['id'] diff --git a/discord/client.py b/discord/client.py index 690a7f6d2..02bd5d6fe 100644 --- a/discord/client.py +++ b/discord/client.py @@ -138,7 +138,8 @@ class Client: if max_messages is None or max_messages < 100: max_messages = 5000 - self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop) + self.connection = ConnectionState(self.dispatch, self.request_offline_members, + self._syncer, max_messages, loop=self.loop) connector = options.pop('connector', None) self.http = HTTPClient(connector, loop=self.loop) @@ -149,6 +150,10 @@ class Client: # internals + @asyncio.coroutine + def _syncer(self, guilds): + yield from self.ws.request_sync(guilds) + def _get_cache_filename(self, email): filename = hashlib.md5(email.encode('utf-8')).hexdigest() return os.path.join(tempfile.gettempdir(), 'discord_py', filename) @@ -295,12 +300,16 @@ class Client: @asyncio.coroutine def _login_1(self, token, **kwargs): log.info('logging in using static token') - yield from self.http.static_login(token, bot=kwargs.pop('bot', True)) + is_bot = kwargs.pop('bot', True) + yield from self.http.static_login(token, bot=is_bot) + self.connection.is_bot = is_bot self._is_logged_in.set() @asyncio.coroutine def _login_2(self, email, password, **kwargs): # attempt to read the token from cache + self.connection.is_bot = False + if self.cache_auth: token = self._get_cache_token(email, password) try: diff --git a/discord/gateway.py b/discord/gateway.py index 91f8e0789..c605cb95f 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -127,6 +127,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): INVALIDATE_SESSION Receive only. Tells the client to invalidate the session and IDENTIFY again. + HELLO + Receive only. Tells the client the heartbeat interval. + HEARTBEAT_ACK + Receive only. Confirms receiving of a heartbeat. Not having it implies + a connection issue. + GUILD_SYNC + Send only. Requests a guild sync. gateway The gateway we are currently connected to. token @@ -143,6 +150,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): RECONNECT = 7 REQUEST_MEMBERS = 8 INVALIDATE_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 def __init__(self, *args, **kwargs): super().__init__(*args, max_size=None, **kwargs) @@ -172,6 +182,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): client.connection._update_references(ws) log.info('Created websocket connected to {}'.format(gateway)) + + # poll the event for OP HELLO + yield from ws.poll_event() + if not resume: yield from ws.identify() log.info('sent the identify payload to create the websocket') @@ -232,6 +246,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): 'v': 3 } } + + if not self._connection.is_bot: + payload['d']['synced_guilds'] = [] + yield from self.send_as_json(payload) @asyncio.coroutine @@ -277,6 +295,12 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): yield from self.close() raise ReconnectWebSocket() + if op == self.HELLO: + interval = data['heartbeat_interval'] / 1000.0 + self._keep_alive = KeepAliveHandler(ws=self, interval=interval) + self._keep_alive.start() + return + if op == self.INVALIDATE_SESSION: state.sequence = None state.session_id = None @@ -298,11 +322,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): state.sequence = msg['s'] state.session_id = data['session_id'] - if is_ready or event == 'RESUMED': - interval = data['heartbeat_interval'] / 1000.0 - self._keep_alive = KeepAliveHandler(ws=self, interval=interval) - self._keep_alive.start() - parser = 'parse_' + event.lower() try: @@ -400,6 +419,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): status = Status.idle if idle_since else Status.online me.status = status + @asyncio.coroutine + def request_sync(self, guild_ids): + payload = { + 'op': self.GUILD_SYNC, + 'd': list(guild_ids) + } + yield from self.send_as_json(payload) + @asyncio.coroutine def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): payload = { diff --git a/discord/http.py b/discord/http.py index d7e90b880..992d76350 100644 --- a/discord/http.py +++ b/discord/http.py @@ -497,4 +497,4 @@ class HTTPClient: data = yield from self.get(self.GATEWAY, bucket=_func_()) except HTTPException as e: raise GatewayNotFound() from e - return data.get('url') + '?encoding=json&v=4' + return data.get('url') + '?encoding=json&v=5' diff --git a/discord/server.py b/discord/server.py index e330fe6b9..02873c464 100644 --- a/discord/server.py +++ b/discord/server.py @@ -169,7 +169,6 @@ class Server(Hashable): self._member_count = member_count self.name = guild.get('name') - self.large = guild.get('large', None if member_count is None else self._member_count > 250) self.region = guild.get('region') try: self.region = ServerRegion(self.region) @@ -181,24 +180,36 @@ class Server(Hashable): self.unavailable = guild.get('unavailable', False) self.id = guild['id'] self.roles = [Role(server=self, **r) for r in guild.get('roles', [])] + self._sync(guild) + self.large = None if member_count is None else self._member_count > 250 - for data in guild.get('members', []): + if 'owner_id' in guild: + self.owner_id = guild['owner_id'] + self.owner = self.get_member(self.owner_id) + + afk_id = guild.get('afk_channel_id') + self.afk_channel = self.get_channel(afk_id) + + for obj in guild.get('voice_states', []): + self._update_voice_state(obj) + + def _sync(self, data): + if 'large' in data: + self.large = data['large'] + + for mdata in data.get('members', []): roles = [self.default_role] - for role_id in data['roles']: + for role_id in mdata['roles']: role = utils.find(lambda r: r.id == role_id, self.roles) if role is not None: roles.append(role) - data['roles'] = roles - member = Member(**data) + mdata['roles'] = roles + member = Member(**mdata) member.server = self self._add_member(member) - if 'owner_id' in guild: - self.owner_id = guild['owner_id'] - self.owner = self.get_member(self.owner_id) - - for presence in guild.get('presences', []): + for presence in data.get('presences', []): user_id = presence['user']['id'] member = self.get_member(user_id) if member is not None: @@ -210,17 +221,12 @@ class Server(Hashable): game = presence.get('game', {}) member.game = Game(**game) if game else None - if 'channels' in guild: - channels = guild['channels'] + if 'channels' in data: + channels = data['channels'] for c in channels: channel = Channel(server=self, **c) self._add_channel(channel) - afk_id = guild.get('afk_channel_id') - self.afk_channel = self.get_channel(afk_id) - - for obj in guild.get('voice_states', []): - self._update_voice_state(obj) @utils.cached_slot_property('_default_role') def default_role(self): diff --git a/discord/state.py b/discord/state.py index 17bc5d2ce..681da885d 100644 --- a/discord/state.py +++ b/discord/state.py @@ -49,11 +49,13 @@ log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'servers')) class ConnectionState: - def __init__(self, dispatch, chunker, max_messages, *, loop): + def __init__(self, dispatch, chunker, syncer, max_messages, *, loop): self.loop = loop self.max_messages = max_messages self.dispatch = dispatch self.chunker = chunker + self.syncer = syncer + self.is_bot = None self._listeners = [] self.clear() @@ -165,8 +167,9 @@ class ConnectionState: launch.set() yield from asyncio.sleep(2) - # get all the chunks servers = self._ready_state.servers + + # get all the chunks chunks = [] for server in servers: chunks.extend(self.chunks_needed(server)) @@ -194,9 +197,12 @@ class ConnectionState: servers = self._ready_state.servers for guild in guilds: server = self._add_server_from_data(guild) - if server.large: + if server.large or not self.is_bot: servers.append(server) + if not self.is_bot: + compat.create_task(self.syncer([s.id for s in self.servers]), loop=self.loop) + for pm in data.get('private_channels'): self._add_private_channel(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) @@ -427,6 +433,10 @@ class ConnectionState: else: self.dispatch('server_join', server) + def parse_guild_sync(self, data): + server = self._get_server(data.get('id')) + server._sync(data) + def parse_guild_update(self, data): server = self._get_server(data.get('id')) if server is not None: