10 changed files with 451 additions and 153 deletions
@ -0,0 +1,402 @@ |
|||||
|
# -*- coding: utf-8 -*- |
||||
|
|
||||
|
""" |
||||
|
The MIT License (MIT) |
||||
|
|
||||
|
Copyright (c) 2015-2016 Rapptz |
||||
|
|
||||
|
Permission is hereby granted, free of charge, to any person obtaining a |
||||
|
copy of this software and associated documentation files (the "Software"), |
||||
|
to deal in the Software without restriction, including without limitation |
||||
|
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
||||
|
and/or sell copies of the Software, and to permit persons to whom the |
||||
|
Software is furnished to do so, subject to the following conditions: |
||||
|
|
||||
|
The above copyright notice and this permission notice shall be included in |
||||
|
all copies or substantial portions of the Software. |
||||
|
|
||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
||||
|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
||||
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
||||
|
DEALINGS IN THE SOFTWARE. |
||||
|
""" |
||||
|
|
||||
|
import sys |
||||
|
import websockets |
||||
|
import asyncio |
||||
|
import aiohttp |
||||
|
from . import utils, endpoints, compat |
||||
|
from .enums import Status |
||||
|
from .game import Game |
||||
|
from .errors import GatewayNotFound, ConnectionClosed, InvalidArgument |
||||
|
import logging |
||||
|
import zlib, time, json |
||||
|
from collections import namedtuple |
||||
|
import threading |
||||
|
|
||||
|
log = logging.getLogger(__name__) |
||||
|
|
||||
|
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', |
||||
|
'KeepAliveHandler' ] |
||||
|
|
||||
|
class ReconnectWebSocket(Exception): |
||||
|
"""Signals to handle the RECONNECT opcode.""" |
||||
|
pass |
||||
|
|
||||
|
EventListener = namedtuple('EventListener', 'predicate event result future') |
||||
|
|
||||
|
class KeepAliveHandler(threading.Thread): |
||||
|
def __init__(self, *args, **kwargs): |
||||
|
ws = kwargs.pop('ws', None) |
||||
|
interval = kwargs.pop('interval', None) |
||||
|
threading.Thread.__init__(self, *args, **kwargs) |
||||
|
self.ws = ws |
||||
|
self.interval = interval |
||||
|
self.daemon = True |
||||
|
self._stop = threading.Event() |
||||
|
|
||||
|
def run(self): |
||||
|
while not self._stop.wait(self.interval): |
||||
|
data = self.get_payload() |
||||
|
msg = 'Keeping websocket alive with sequence {0[d]}'.format(data) |
||||
|
log.debug(msg) |
||||
|
coro = self.ws.send_as_json(data) |
||||
|
f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop) |
||||
|
try: |
||||
|
# block until sending is complete |
||||
|
f.result() |
||||
|
except Exception: |
||||
|
self.stop() |
||||
|
|
||||
|
def get_payload(self): |
||||
|
return { |
||||
|
'op': self.ws.HEARTBEAT, |
||||
|
'd': self.ws._connection.sequence |
||||
|
} |
||||
|
|
||||
|
def stop(self): |
||||
|
self._stop.set() |
||||
|
|
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def get_gateway(token, *, loop=None): |
||||
|
"""Returns the gateway URL for connecting to the WebSocket. |
||||
|
|
||||
|
Parameters |
||||
|
----------- |
||||
|
token : str |
||||
|
The discord authentication token. |
||||
|
loop |
||||
|
The event loop. |
||||
|
|
||||
|
Raises |
||||
|
------ |
||||
|
GatewayNotFound |
||||
|
When the gateway is not returned gracefully. |
||||
|
""" |
||||
|
headers = { |
||||
|
'authorization': token, |
||||
|
'content-type': 'application/json' |
||||
|
} |
||||
|
|
||||
|
with aiohttp.ClientSession(loop=loop) as session: |
||||
|
resp = yield from session.get(endpoints.GATEWAY, headers=headers) |
||||
|
if resp.status != 200: |
||||
|
yield from resp.release() |
||||
|
raise GatewayNotFound() |
||||
|
data = yield from resp.json() |
||||
|
return data.get('url') |
||||
|
|
||||
|
class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
||||
|
"""Implements a WebSocket for Discord's gateway v4. |
||||
|
|
||||
|
This is created through :func:`create_main_websocket`. Library |
||||
|
users should never create this manually. |
||||
|
|
||||
|
Attributes |
||||
|
----------- |
||||
|
DISPATCH |
||||
|
Receive only. Denotes an event to be sent to Discord, such as READY. |
||||
|
HEARTBEAT |
||||
|
When received tells Discord to keep the connection alive. |
||||
|
When sent asks if your connection is currently alive. |
||||
|
IDENTIFY |
||||
|
Send only. Starts a new session. |
||||
|
PRESENCE |
||||
|
Send only. Updates your presence. |
||||
|
VOICE_STATE |
||||
|
Send only. Starts a new connection to a voice server. |
||||
|
VOICE_PING |
||||
|
Send only. Checks ping time to a voice server, do not use. |
||||
|
RESUME |
||||
|
Send only. Resumes an existing connection. |
||||
|
RECONNECT |
||||
|
Receive only. Tells the client to reconnect to a new gateway. |
||||
|
REQUEST_MEMBERS |
||||
|
Send only. Asks for the full member list of a server. |
||||
|
INVALIDATE_SESSION |
||||
|
Receive only. Tells the client to invalidate the session and IDENTIFY |
||||
|
again. |
||||
|
gateway |
||||
|
The gateway we are currently connected to. |
||||
|
token |
||||
|
The authentication token for discord. |
||||
|
""" |
||||
|
|
||||
|
DISPATCH = 0 |
||||
|
HEARTBEAT = 1 |
||||
|
IDENTIFY = 2 |
||||
|
PRESENCE = 3 |
||||
|
VOICE_STATE = 4 |
||||
|
VOICE_PING = 5 |
||||
|
RESUME = 6 |
||||
|
RECONNECT = 7 |
||||
|
REQUEST_MEMBERS = 8 |
||||
|
INVALIDATE_SESSION = 9 |
||||
|
|
||||
|
def __init__(self, *args, **kwargs): |
||||
|
super().__init__(*args, max_size=None, **kwargs) |
||||
|
# an empty dispatcher to prevent crashes |
||||
|
self._dispatch = lambda *args: None |
||||
|
# generic event listeners |
||||
|
self._dispatch_listeners = [] |
||||
|
# the keep alive |
||||
|
self._keep_alive = None |
||||
|
|
||||
|
@classmethod |
||||
|
@asyncio.coroutine |
||||
|
def connect(cls, dispatch, *, token=None, connection=None, loop=None): |
||||
|
"""Creates a main websocket for Discord used for the client. |
||||
|
|
||||
|
Parameters |
||||
|
---------- |
||||
|
token : str |
||||
|
The token for Discord authentication. |
||||
|
connection |
||||
|
The ConnectionState for the client. |
||||
|
dispatch |
||||
|
The function that dispatches events. |
||||
|
loop |
||||
|
The event loop to use. |
||||
|
|
||||
|
Returns |
||||
|
------- |
||||
|
DiscordWebSocket |
||||
|
A websocket connected to Discord. |
||||
|
""" |
||||
|
|
||||
|
gateway = yield from get_gateway(token, loop=loop) |
||||
|
ws = yield from websockets.connect(gateway, loop=loop, klass=cls) |
||||
|
|
||||
|
# dynamically add attributes needed |
||||
|
ws.token = token |
||||
|
ws._connection = connection |
||||
|
ws._dispatch = dispatch |
||||
|
ws.gateway = gateway |
||||
|
|
||||
|
log.info('Created websocket connected to {}'.format(gateway)) |
||||
|
yield from ws.identify() |
||||
|
log.info('sent the identify payload to create the websocket') |
||||
|
return ws |
||||
|
|
||||
|
@classmethod |
||||
|
def from_client(cls, client): |
||||
|
"""Creates a main websocket for Discord from a :class:`Client`. |
||||
|
|
||||
|
This is for internal use only. |
||||
|
""" |
||||
|
return cls.connect(client.dispatch, token=client.token, |
||||
|
connection=client.connection, |
||||
|
loop=client.loop) |
||||
|
|
||||
|
def wait_for(self, event, predicate, result): |
||||
|
"""Waits for a DISPATCH'd event that meets the predicate. |
||||
|
|
||||
|
Parameters |
||||
|
----------- |
||||
|
event : str |
||||
|
The event name in all upper case to wait for. |
||||
|
predicate |
||||
|
A function that takes a data parameter to check for event |
||||
|
properties. The data parameter is the 'd' key in the JSON message. |
||||
|
result |
||||
|
A function that takes the same data parameter and executes to send |
||||
|
the result to the future. |
||||
|
|
||||
|
Returns |
||||
|
-------- |
||||
|
asyncio.Future |
||||
|
A future to wait for. |
||||
|
""" |
||||
|
|
||||
|
future = asyncio.Future(loop=self.loop) |
||||
|
entry = EventListener(event=event, predicate=predicate, result=result, future=future) |
||||
|
self._dispatch_listeners.append(entry) |
||||
|
return future |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def identify(self): |
||||
|
"""Sends the IDENTIFY packet.""" |
||||
|
payload = { |
||||
|
'op': self.IDENTIFY, |
||||
|
'd': { |
||||
|
'token': self.token, |
||||
|
'properties': { |
||||
|
'$os': sys.platform, |
||||
|
'$browser': 'discord.py', |
||||
|
'$device': 'discord.py', |
||||
|
'$referrer': '', |
||||
|
'$referring_domain': '' |
||||
|
}, |
||||
|
'compress': True, |
||||
|
'large_threshold': 250, |
||||
|
'v': 3 |
||||
|
} |
||||
|
} |
||||
|
yield from self.send_as_json(payload) |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def received_message(self, msg): |
||||
|
self._dispatch('socket_raw_receive', msg) |
||||
|
|
||||
|
if isinstance(msg, bytes): |
||||
|
msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB |
||||
|
msg = msg.decode('utf-8') |
||||
|
|
||||
|
msg = json.loads(msg) |
||||
|
|
||||
|
log.debug('WebSocket Event: {}'.format(msg)) |
||||
|
self._dispatch('socket_response', msg) |
||||
|
|
||||
|
op = msg.get('op') |
||||
|
data = msg.get('d') |
||||
|
|
||||
|
if 's' in msg: |
||||
|
self._connection.sequence = msg['s'] |
||||
|
|
||||
|
if op == self.RECONNECT: |
||||
|
# "reconnect" can only be handled by the Client |
||||
|
# so we terminate our connection and raise an |
||||
|
# internal exception signalling to reconnect. |
||||
|
yield from self.close() |
||||
|
raise ReconnectWebSocket() |
||||
|
|
||||
|
if op == self.INVALIDATE_SESSION: |
||||
|
self._connection.sequence = None |
||||
|
self._connection.session_id = None |
||||
|
return |
||||
|
|
||||
|
if op != self.DISPATCH: |
||||
|
log.info('Unhandled op {}'.format(op)) |
||||
|
return |
||||
|
|
||||
|
event = msg.get('t') |
||||
|
is_ready = event == 'READY' |
||||
|
|
||||
|
if is_ready: |
||||
|
self._connection.clear() |
||||
|
self._connection.sequence = msg['s'] |
||||
|
self._connection.session_id = data['session_id'] |
||||
|
|
||||
|
if is_ready or event == 'RESUMED': |
||||
|
interval = data['heartbeat_interval'] / 1000.0 |
||||
|
self._keep_alive = KeepAliveHandler(ws=self, interval=interval) |
||||
|
self._keep_alive.start() |
||||
|
|
||||
|
parser = 'parse_' + event.lower() |
||||
|
|
||||
|
try: |
||||
|
func = getattr(self._connection, parser) |
||||
|
except AttributeError: |
||||
|
log.info('Unhandled event {}'.format(event)) |
||||
|
else: |
||||
|
func(data) |
||||
|
|
||||
|
# remove the dispatched listeners |
||||
|
removed = [] |
||||
|
for index, entry in enumerate(self._dispatch_listeners): |
||||
|
if entry.event != event: |
||||
|
continue |
||||
|
|
||||
|
future = entry.future |
||||
|
if future.cancelled(): |
||||
|
removed.append(index) |
||||
|
|
||||
|
try: |
||||
|
valid = entry.predicate(data) |
||||
|
except Exception as e: |
||||
|
future.set_exception(e) |
||||
|
removed.append(index) |
||||
|
else: |
||||
|
if valid: |
||||
|
future.set_result(entry.result) |
||||
|
removed.append(index) |
||||
|
|
||||
|
for index in reversed(removed): |
||||
|
del self._dispatch_listeners[index] |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def poll_event(self): |
||||
|
"""Polls for a DISPATCH event and handles the general gateway loop. |
||||
|
|
||||
|
Raises |
||||
|
------ |
||||
|
ConnectionClosed |
||||
|
The websocket connection was terminated for unhandled reasons. |
||||
|
""" |
||||
|
try: |
||||
|
msg = yield from self.recv() |
||||
|
yield from self.received_message(msg) |
||||
|
except websockets.exceptions.ConnectionClosed as e: |
||||
|
if e.code in (4008, 4009) or e.code in range(1001, 1015): |
||||
|
raise ReconnectWebSocket() from e |
||||
|
else: |
||||
|
raise ConnectionClosed(e) from e |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def send(self, data): |
||||
|
self._dispatch('socket_raw_send', data) |
||||
|
yield from super().send(data) |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def send_as_json(self, data): |
||||
|
yield from super().send(utils.to_json(data)) |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def change_presence(self, *, game=None, idle=None): |
||||
|
if game is not None and not isinstance(game, Game): |
||||
|
raise InvalidArgument('game must be of Game or None') |
||||
|
|
||||
|
idle_since = None if idle == False else int(time.time() * 1000) |
||||
|
sent_game = game and {'name': game.name} |
||||
|
|
||||
|
payload = { |
||||
|
'op': self.PRESENCE, |
||||
|
'd': { |
||||
|
'game': sent_game, |
||||
|
'idle_since': idle_since |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
sent = utils.to_json(payload) |
||||
|
log.debug('Sending "{}" to change status'.format(sent)) |
||||
|
yield from self.send(sent) |
||||
|
|
||||
|
for server in self._connection.servers: |
||||
|
me = server.me |
||||
|
if me is None: |
||||
|
continue |
||||
|
|
||||
|
me.game = game |
||||
|
status = Status.idle if idle_since else Status.online |
||||
|
me.status = status |
||||
|
|
||||
|
@asyncio.coroutine |
||||
|
def close(self, code=1000, reason=''): |
||||
|
if self._keep_alive: |
||||
|
self._keep_alive.stop() |
||||
|
|
||||
|
yield from super().close(code, reason) |
@ -1,3 +1,3 @@ |
|||||
aiohttp>=0.21.0,<0.22.0 |
aiohttp>=0.21.0,<0.22.0 |
||||
websockets==2.7 |
websockets==3.1 |
||||
PyNaCl==1.0.1 |
PyNaCl==1.0.1 |
||||
|
Loading…
Reference in new issue