14 changed files with 654 additions and 38 deletions
@ -0,0 +1,160 @@ |
|||
from functools import partial |
|||
import uuid |
|||
|
|||
import json |
|||
import pickle |
|||
import six |
|||
|
|||
from .asyncio_manager import AsyncManager |
|||
|
|||
|
|||
class AsyncPubSubManager(AsyncManager): |
|||
"""Manage a client list attached to a pub/sub backend under asyncio. |
|||
|
|||
This is a base class that enables multiple servers to share the list of |
|||
clients, with the servers communicating events through a pub/sub backend. |
|||
The use of a pub/sub backend also allows any client connected to the |
|||
backend to emit events addressed to Socket.IO clients. |
|||
|
|||
The actual backends must be implemented by subclasses, this class only |
|||
provides a pub/sub generic framework for asyncio applications. |
|||
|
|||
:param channel: The channel name on which the server sends and receives |
|||
notifications. |
|||
""" |
|||
name = 'asyncpubsub' |
|||
|
|||
def __init__(self, channel='socketio', write_only=False): |
|||
super().__init__() |
|||
self.channel = channel |
|||
self.write_only = write_only |
|||
self.host_id = uuid.uuid4().hex |
|||
|
|||
def initialize(self): |
|||
super().initialize() |
|||
if not self.write_only: |
|||
self.thread = self.server.start_background_task(self._thread) |
|||
self.server.logger.info(self.name + ' backend initialized.') |
|||
|
|||
async def emit(self, event, data, namespace=None, room=None, skip_sid=None, |
|||
callback=None, **kwargs): |
|||
"""Emit a message to a single client, a room, or all the clients |
|||
connected to the namespace. |
|||
|
|||
This method takes care or propagating the message to all the servers |
|||
that are connected through the message queue. |
|||
|
|||
The parameters are the same as in :meth:`.Server.emit`. |
|||
|
|||
Note: this method is a coroutine. |
|||
""" |
|||
if kwargs.get('ignore_queue'): |
|||
return await super().emit( |
|||
event, data, namespace=namespace, room=room, skip_sid=skip_sid, |
|||
callback=callback) |
|||
namespace = namespace or '/' |
|||
if callback is not None: |
|||
if self.server is None: |
|||
raise RuntimeError('Callbacks can only be issued from the ' |
|||
'context of a server.') |
|||
if room is None: |
|||
raise ValueError('Cannot use callback without a room set.') |
|||
id = self._generate_ack_id(room, namespace, callback) |
|||
callback = (room, namespace, id) |
|||
else: |
|||
callback = None |
|||
await self._publish({'method': 'emit', 'event': event, 'data': data, |
|||
'namespace': namespace, 'room': room, |
|||
'skip_sid': skip_sid, 'callback': callback}) |
|||
|
|||
async def close_room(self, room, namespace=None): |
|||
await self._publish({'method': 'close_room', 'room': room, |
|||
'namespace': namespace or '/'}) |
|||
|
|||
async def _publish(self, data): |
|||
"""Publish a message on the Socket.IO channel. |
|||
|
|||
This method needs to be implemented by the different subclasses that |
|||
support pub/sub backends. |
|||
""" |
|||
raise NotImplementedError('This method must be implemented in a ' |
|||
'subclass.') # pragma: no cover |
|||
|
|||
async def _listen(self): |
|||
"""Return the next message published on the Socket.IO channel, |
|||
blocking until a message is available. |
|||
|
|||
This method needs to be implemented by the different subclasses that |
|||
support pub/sub backends. |
|||
""" |
|||
raise NotImplementedError('This method must be implemented in a ' |
|||
'subclass.') # pragma: no cover |
|||
|
|||
async def _handle_emit(self, message): |
|||
# Events with callbacks are very tricky to handle across hosts |
|||
# Here in the receiving end we set up a local callback that preserves |
|||
# the callback host and id from the sender |
|||
remote_callback = message.get('callback') |
|||
if remote_callback is not None and len(remote_callback) == 3: |
|||
callback = partial(self._return_callback, self.host_id, |
|||
*remote_callback) |
|||
else: |
|||
callback = None |
|||
await super().emit(message['event'], message['data'], |
|||
namespace=message.get('namespace'), |
|||
room=message.get('room'), |
|||
skip_sid=message.get('skip_sid'), |
|||
callback=callback) |
|||
|
|||
async def _handle_callback(self, message): |
|||
if self.host_id == message.get('host_id'): |
|||
try: |
|||
sid = message['sid'] |
|||
namespace = message['namespace'] |
|||
id = message['id'] |
|||
args = message['args'] |
|||
except KeyError: |
|||
return |
|||
await self.trigger_callback(sid, namespace, id, args) |
|||
|
|||
async def _return_callback(self, host_id, sid, namespace, callback_id, |
|||
*args): |
|||
# When an event callback is received, the callback is returned back |
|||
# the sender, which is identified by the host_id |
|||
await self._publish({'method': 'callback', 'host_id': host_id, |
|||
'sid': sid, 'namespace': namespace, |
|||
'id': callback_id, 'args': args}) |
|||
|
|||
async def _handle_close_room(self, message): |
|||
await super().close_room( |
|||
room=message.get('room'), namespace=message.get('namespace')) |
|||
|
|||
async def _thread(self): |
|||
while True: |
|||
try: |
|||
message = await self._listen() |
|||
except: |
|||
import traceback |
|||
traceback.print_exc() |
|||
break |
|||
data = None |
|||
if isinstance(message, dict): |
|||
data = message |
|||
else: |
|||
if isinstance(message, six.binary_type): # pragma: no cover |
|||
try: |
|||
data = pickle.loads(message) |
|||
except: |
|||
pass |
|||
if data is None: |
|||
try: |
|||
data = json.loads(message) |
|||
except: |
|||
pass |
|||
if data and 'method' in data: |
|||
if data['method'] == 'emit': |
|||
await self._handle_emit(data) |
|||
elif data['method'] == 'callback': |
|||
await self._handle_callback(data) |
|||
elif data['method'] == 'close_room': |
|||
await self._handle_close_room(data) |
@ -0,0 +1,78 @@ |
|||
import pickle |
|||
from urllib.parse import urlparse |
|||
|
|||
try: |
|||
import aioredis |
|||
except ImportError: |
|||
aioredis = None |
|||
|
|||
from .asyncio_pubsub_manager import AsyncPubSubManager |
|||
|
|||
|
|||
def _parse_redis_url(url): |
|||
p = urlparse(url) |
|||
if p.scheme != 'redis': |
|||
raise ValueError('Invalid redis url') |
|||
if ':' in p.netloc: |
|||
host, port = p.netloc.split(':') |
|||
port = int(port) |
|||
else: |
|||
host = p.netloc or 'localhost' |
|||
port = 6379 |
|||
if p.path: |
|||
db = int(p.path[1:]) |
|||
else: |
|||
db = 0 |
|||
if not host: |
|||
raise ValueError('Invalid redis hostname') |
|||
return host, port, db |
|||
|
|||
|
|||
class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover |
|||
"""Redis based client manager for asyncio servers. |
|||
|
|||
This class implements a Redis backend for event sharing across multiple |
|||
processes. Only kept here as one more example of how to build a custom |
|||
backend, since the kombu backend is perfectly adequate to support a Redis |
|||
message queue. |
|||
|
|||
To use a Redis backend, initialize the :class:`Server` instance as |
|||
follows:: |
|||
|
|||
server = socketio.Server(client_manager=socketio.AsyncRedisManager( |
|||
'redis://hostname:port/0')) |
|||
|
|||
:param url: The connection URL for the Redis server. For a default Redis |
|||
store running on the same host, use ``redis://``. |
|||
:param channel: The channel name on which the server sends and receives |
|||
notifications. Must be the same in all the servers. |
|||
:param write_only: If set ot ``True``, only initialize to emit events. The |
|||
default of ``False`` initializes the class for emitting |
|||
and receiving. |
|||
""" |
|||
name = 'aioredis' |
|||
|
|||
def __init__(self, url='redis://localhost:6379/0', channel='socketio', |
|||
write_only=False): |
|||
if aioredis is None: |
|||
raise RuntimeError('Redis package is not installed ' |
|||
'(Run "pip install aioredis" in your ' |
|||
'virtualenv).') |
|||
self.host, self.port, self.db = _parse_redis_url(url) |
|||
self.pub = None |
|||
self.sub = None |
|||
super().__init__(channel=channel, write_only=write_only) |
|||
|
|||
async def _publish(self, data): |
|||
if self.pub is None: |
|||
self.pub = await aioredis.create_redis((self.host, self.port), |
|||
db=self.db) |
|||
return await self.pub.publish(self.channel, pickle.dumps(data)) |
|||
|
|||
async def _listen(self): |
|||
if self.sub is None: |
|||
self.sub = await aioredis.create_redis((self.host, self.port), |
|||
db=self.db) |
|||
self.ch = (await self.sub.subscribe(self.channel))[0] |
|||
while True: |
|||
return await self.ch.get() |
@ -0,0 +1,268 @@ |
|||
import functools |
|||
import sys |
|||
import unittest |
|||
|
|||
import six |
|||
if six.PY3: |
|||
from unittest import mock |
|||
else: |
|||
import mock |
|||
|
|||
if sys.version_info >= (3, 5): |
|||
import asyncio |
|||
from asyncio import coroutine |
|||
from socketio import asyncio_manager |
|||
from socketio import asyncio_pubsub_manager |
|||
else: |
|||
# mock coroutine so that Python 2 doesn't complain |
|||
def coroutine(f): |
|||
return f |
|||
|
|||
|
|||
def AsyncMock(*args, **kwargs): |
|||
"""Return a mock asynchronous function.""" |
|||
m = mock.MagicMock(*args, **kwargs) |
|||
|
|||
@coroutine |
|||
def mock_coro(*args, **kwargs): |
|||
return m(*args, **kwargs) |
|||
|
|||
mock_coro.mock = m |
|||
return mock_coro |
|||
|
|||
|
|||
def _run(coro): |
|||
"""Run the given coroutine.""" |
|||
return asyncio.get_event_loop().run_until_complete(coro) |
|||
|
|||
|
|||
@unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') |
|||
class TestAsyncPubSubManager(unittest.TestCase): |
|||
def setUp(self): |
|||
mock_server = mock.MagicMock() |
|||
mock_server._emit_internal = AsyncMock() |
|||
self.pm = asyncio_pubsub_manager.AsyncPubSubManager() |
|||
self.pm._publish = AsyncMock() |
|||
self.pm.set_server(mock_server) |
|||
self.pm.initialize() |
|||
|
|||
def test_default_init(self): |
|||
self.assertEqual(self.pm.channel, 'socketio') |
|||
self.assertEqual(len(self.pm.host_id), 32) |
|||
self.pm.server.start_background_task.assert_called_once_with( |
|||
self.pm._thread) |
|||
|
|||
def test_custom_init(self): |
|||
pubsub = asyncio_pubsub_manager.AsyncPubSubManager(channel='foo') |
|||
self.assertEqual(pubsub.channel, 'foo') |
|||
self.assertEqual(len(pubsub.host_id), 32) |
|||
|
|||
def test_write_only_init(self): |
|||
mock_server = mock.MagicMock() |
|||
pm = asyncio_pubsub_manager.AsyncPubSubManager(write_only=True) |
|||
pm.set_server(mock_server) |
|||
pm.initialize() |
|||
self.assertEqual(pm.channel, 'socketio') |
|||
self.assertEqual(len(pm.host_id), 32) |
|||
self.assertEqual(pm.server.start_background_task.call_count, 0) |
|||
|
|||
def test_emit(self): |
|||
_run(self.pm.emit('foo', 'bar')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'emit', 'event': 'foo', 'data': 'bar', |
|||
'namespace': '/', 'room': None, 'skip_sid': None, |
|||
'callback': None}) |
|||
|
|||
def test_emit_with_namespace(self): |
|||
_run(self.pm.emit('foo', 'bar', namespace='/baz')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'emit', 'event': 'foo', 'data': 'bar', |
|||
'namespace': '/baz', 'room': None, 'skip_sid': None, |
|||
'callback': None}) |
|||
|
|||
def test_emit_with_room(self): |
|||
_run(self.pm.emit('foo', 'bar', room='baz')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'emit', 'event': 'foo', 'data': 'bar', |
|||
'namespace': '/', 'room': 'baz', 'skip_sid': None, |
|||
'callback': None}) |
|||
|
|||
def test_emit_with_skip_sid(self): |
|||
_run(self.pm.emit('foo', 'bar', skip_sid='baz')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'emit', 'event': 'foo', 'data': 'bar', |
|||
'namespace': '/', 'room': None, 'skip_sid': 'baz', |
|||
'callback': None}) |
|||
|
|||
def test_emit_with_callback(self): |
|||
with mock.patch.object(self.pm, '_generate_ack_id', |
|||
return_value='123'): |
|||
_run(self.pm.emit('foo', 'bar', room='baz', callback='cb')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'emit', 'event': 'foo', 'data': 'bar', |
|||
'namespace': '/', 'room': 'baz', 'skip_sid': None, |
|||
'callback': ('baz', '/', '123')}) |
|||
|
|||
def test_emit_with_callback_without_server(self): |
|||
standalone_pm = asyncio_pubsub_manager.AsyncPubSubManager() |
|||
self.assertRaises(RuntimeError, _run, |
|||
standalone_pm.emit('foo', 'bar', callback='cb')) |
|||
|
|||
def test_emit_with_callback_missing_room(self): |
|||
with mock.patch.object(self.pm, '_generate_ack_id', |
|||
return_value='123'): |
|||
self.assertRaises(ValueError, _run, |
|||
self.pm.emit('foo', 'bar', callback='cb')) |
|||
|
|||
def test_emit_with_ignore_queue(self): |
|||
self.pm.connect('123', '/') |
|||
_run(self.pm.emit('foo', 'bar', room='123', namespace='/', |
|||
ignore_queue=True)) |
|||
self.pm._publish.mock.assert_not_called() |
|||
self.pm.server._emit_internal.mock.assert_called_once_with( |
|||
'123', 'foo', 'bar', '/', None) |
|||
|
|||
def test_close_room(self): |
|||
_run(self.pm.close_room('foo')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'close_room', 'room': 'foo', 'namespace': '/'}) |
|||
|
|||
def test_close_room_with_namespace(self): |
|||
_run(self.pm.close_room('foo', '/bar')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'close_room', 'room': 'foo', 'namespace': '/bar'}) |
|||
|
|||
def test_handle_emit(self): |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'emit', |
|||
new=AsyncMock()) as super_emit: |
|||
_run(self.pm._handle_emit({'event': 'foo', 'data': 'bar'})) |
|||
super_emit.mock.assert_called_once_with( |
|||
self.pm, 'foo', 'bar', namespace=None, room=None, |
|||
skip_sid=None, callback=None) |
|||
|
|||
def test_handle_emit_with_namespace(self): |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'emit', |
|||
new=AsyncMock()) as super_emit: |
|||
_run(self.pm._handle_emit({'event': 'foo', 'data': 'bar', |
|||
'namespace': '/baz'})) |
|||
super_emit.mock.assert_called_once_with( |
|||
self.pm, 'foo', 'bar', namespace='/baz', room=None, |
|||
skip_sid=None, callback=None) |
|||
|
|||
def test_handle_emiti_with_room(self): |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'emit', |
|||
new=AsyncMock()) as super_emit: |
|||
_run(self.pm._handle_emit({'event': 'foo', 'data': 'bar', |
|||
'room': 'baz'})) |
|||
super_emit.mock.assert_called_once_with( |
|||
self.pm, 'foo', 'bar', namespace=None, room='baz', |
|||
skip_sid=None, callback=None) |
|||
|
|||
def test_handle_emit_with_skip_sid(self): |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'emit', |
|||
new=AsyncMock()) as super_emit: |
|||
_run(self.pm._handle_emit({'event': 'foo', 'data': 'bar', |
|||
'skip_sid': '123'})) |
|||
super_emit.mock.assert_called_once_with( |
|||
self.pm, 'foo', 'bar', namespace=None, room=None, |
|||
skip_sid='123', callback=None) |
|||
|
|||
def test_handle_emit_with_callback(self): |
|||
host_id = self.pm.host_id |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'emit', |
|||
new=AsyncMock()) as super_emit: |
|||
_run(self.pm._handle_emit({'event': 'foo', 'data': 'bar', |
|||
'namespace': '/baz', |
|||
'callback': ('sid', '/baz', 123)})) |
|||
self.assertEqual(super_emit.mock.call_count, 1) |
|||
self.assertEqual(super_emit.mock.call_args[0], |
|||
(self.pm, 'foo', 'bar')) |
|||
self.assertEqual(super_emit.mock.call_args[1]['namespace'], '/baz') |
|||
self.assertIsNone(super_emit.mock.call_args[1]['room']) |
|||
self.assertIsNone(super_emit.mock.call_args[1]['skip_sid']) |
|||
self.assertIsInstance(super_emit.mock.call_args[1]['callback'], |
|||
functools.partial) |
|||
_run(super_emit.mock.call_args[1]['callback']('one', 2, 'three')) |
|||
self.pm._publish.mock.assert_called_once_with( |
|||
{'method': 'callback', 'host_id': host_id, 'sid': 'sid', |
|||
'namespace': '/baz', 'id': 123, 'args': ('one', 2, 'three')}) |
|||
|
|||
def test_handle_callback(self): |
|||
host_id = self.pm.host_id |
|||
with mock.patch.object(self.pm, 'trigger_callback', |
|||
new=AsyncMock()) as trigger: |
|||
_run(self.pm._handle_callback({'method': 'callback', |
|||
'host_id': host_id, 'sid': 'sid', |
|||
'namespace': '/', 'id': 123, |
|||
'args': ('one', 2)})) |
|||
trigger.mock.assert_called_once_with('sid', '/', 123, ('one', 2)) |
|||
|
|||
def test_handle_callback_bad_host_id(self): |
|||
with mock.patch.object(self.pm, 'trigger_callback', |
|||
new=AsyncMock()) as trigger: |
|||
_run(self.pm._handle_callback({'method': 'callback', |
|||
'host_id': 'bad', 'sid': 'sid', |
|||
'namespace': '/', 'id': 123, |
|||
'args': ('one', 2)})) |
|||
self.assertEqual(trigger.mock.call_count, 0) |
|||
|
|||
def test_handle_callback_missing_args(self): |
|||
host_id = self.pm.host_id |
|||
with mock.patch.object(self.pm, 'trigger_callback', |
|||
new=AsyncMock()) as trigger: |
|||
_run(self.pm._handle_callback({'method': 'callback', |
|||
'host_id': host_id, 'sid': 'sid', |
|||
'namespace': '/', 'id': 123})) |
|||
_run(self.pm._handle_callback({'method': 'callback', |
|||
'host_id': host_id, 'sid': 'sid', |
|||
'namespace': '/'})) |
|||
_run(self.pm._handle_callback({'method': 'callback', |
|||
'host_id': host_id, 'sid': 'sid'})) |
|||
_run(self.pm._handle_callback({'method': 'callback', |
|||
'host_id': host_id})) |
|||
self.assertEqual(trigger.mock.call_count, 0) |
|||
|
|||
def test_handle_close_room(self): |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'close_room', |
|||
new=AsyncMock()) as super_close_room: |
|||
_run(self.pm._handle_close_room({'method': 'close_room', |
|||
'room': 'foo'})) |
|||
super_close_room.mock.assert_called_once_with( |
|||
self.pm, room='foo', namespace=None) |
|||
|
|||
def test_handle_close_room_with_namespace(self): |
|||
with mock.patch.object(asyncio_manager.AsyncManager, 'close_room', |
|||
new=AsyncMock()) as super_close_room: |
|||
_run(self.pm._handle_close_room({'method': 'close_room', |
|||
'room': 'foo', |
|||
'namespace': '/bar'})) |
|||
super_close_room.mock.assert_called_once_with( |
|||
self.pm, room='foo', namespace='/bar') |
|||
|
|||
def test_background_thread(self): |
|||
self.pm._handle_emit = AsyncMock() |
|||
self.pm._handle_callback = AsyncMock() |
|||
self.pm._handle_close_room = AsyncMock() |
|||
|
|||
def messages(): |
|||
import pickle |
|||
yield {'method': 'emit', 'value': 'foo'} |
|||
yield {'missing': 'method'} |
|||
yield '{"method": "callback", "value": "bar"}' |
|||
yield {'method': 'bogus'} |
|||
yield pickle.dumps({'method': 'close_room', 'value': 'baz'}) |
|||
yield 'bad json' |
|||
yield b'bad pickled' |
|||
|
|||
self.pm._listen = AsyncMock(side_effect=list(messages())) |
|||
try: |
|||
_run(self.pm._thread()) |
|||
except StopIteration: |
|||
pass |
|||
|
|||
self.pm._handle_emit.mock.assert_called_once_with( |
|||
{'method': 'emit', 'value': 'foo'}) |
|||
self.pm._handle_callback.mock.assert_called_once_with( |
|||
{'method': 'callback', 'value': 'bar'}) |
|||
self.pm._handle_close_room.mock.assert_called_once_with( |
|||
{'method': 'close_room', 'value': 'baz'}) |
@ -0,0 +1,43 @@ |
|||
import sys |
|||
import unittest |
|||
|
|||
if sys.version_info >= (3, 5): |
|||
from socketio import asyncio_redis_manager |
|||
|
|||
|
|||
@unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') |
|||
class TestAsyncRedisManager(unittest.TestCase): |
|||
def test_default_url(self): |
|||
self.assertEqual(asyncio_redis_manager._parse_redis_url('redis://'), |
|||
('localhost', 6379, 0)) |
|||
|
|||
def test_only_host_url(self): |
|||
self.assertEqual( |
|||
asyncio_redis_manager._parse_redis_url('redis://redis.host'), |
|||
('redis.host', 6379, 0)) |
|||
|
|||
def test_no_db_url(self): |
|||
self.assertEqual( |
|||
asyncio_redis_manager._parse_redis_url('redis://redis.host:123/1'), |
|||
('redis.host', 123, 1)) |
|||
|
|||
def test_no_port_url(self): |
|||
self.assertEqual( |
|||
asyncio_redis_manager._parse_redis_url('redis://redis.host/1'), |
|||
('redis.host', 6379, 1)) |
|||
|
|||
def test_no_host_url(self): |
|||
self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, |
|||
'redis://:123/1') |
|||
|
|||
def test_bad_port_url(self): |
|||
self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, |
|||
'redis://localhost:abc/1') |
|||
|
|||
def test_bad_db_url(self): |
|||
self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, |
|||
'redis://localhost:abc/z') |
|||
|
|||
def test_bad_scheme_url(self): |
|||
self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, |
|||
'http://redis.host:123/1') |
Loading…
Reference in new issue