From e48ad0ab8c2c46e01f8f78116856f124d42b59c7 Mon Sep 17 00:00:00 2001
From: Philipp Reitter
Date: Sat, 9 Dec 2017 13:28:48 +0100
Subject: [PATCH] Properly handle callbacks in multi-host configurations in
asyncio pubsub manager
---
socketio/asyncio_pubsub_manager.py | 6 ++++--
tests/test_asyncio_pubsub_manager.py | 15 ++++++++-------
2 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/socketio/asyncio_pubsub_manager.py b/socketio/asyncio_pubsub_manager.py
index 8442cd1..578e734 100644
--- a/socketio/asyncio_pubsub_manager.py
+++ b/socketio/asyncio_pubsub_manager.py
@@ -65,7 +65,8 @@ class AsyncPubSubManager(AsyncManager):
callback = None
await self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace, 'room': room,
- 'skip_sid': skip_sid, 'callback': callback})
+ 'skip_sid': skip_sid, 'callback': callback,
+ 'host_id': self.host_id})
async def close_room(self, room, namespace=None):
await self._publish({'method': 'close_room', 'room': room,
@@ -95,8 +96,9 @@ class AsyncPubSubManager(AsyncManager):
# Here in the receiving end we set up a local callback that preserves
# the callback host and id from the sender
remote_callback = message.get('callback')
+ remote_host_id = message.get('host_id')
if remote_callback is not None and len(remote_callback) == 3:
- callback = partial(self._return_callback, self.host_id,
+ callback = partial(self._return_callback, remote_host_id,
*remote_callback)
else:
callback = None
diff --git a/tests/test_asyncio_pubsub_manager.py b/tests/test_asyncio_pubsub_manager.py
index 2f556e6..19f4890 100644
--- a/tests/test_asyncio_pubsub_manager.py
+++ b/tests/test_asyncio_pubsub_manager.py
@@ -44,11 +44,11 @@ class TestAsyncPubSubManager(unittest.TestCase):
self.pm = asyncio_pubsub_manager.AsyncPubSubManager()
self.pm._publish = AsyncMock()
self.pm.set_server(mock_server)
+ self.pm.host_id = '123456'
self.pm.initialize()
def test_default_init(self):
self.assertEqual(self.pm.channel, 'socketio')
- self.assertEqual(len(self.pm.host_id), 32)
self.pm.server.start_background_task.assert_called_once_with(
self.pm._thread)
@@ -71,28 +71,28 @@ class TestAsyncPubSubManager(unittest.TestCase):
self.pm._publish.mock.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': None, 'skip_sid': None,
- 'callback': None})
+ 'callback': None, 'host_id': '123456'})
def test_emit_with_namespace(self):
_run(self.pm.emit('foo', 'bar', namespace='/baz'))
self.pm._publish.mock.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/baz', 'room': None, 'skip_sid': None,
- 'callback': None})
+ 'callback': None, 'host_id': '123456'})
def test_emit_with_room(self):
_run(self.pm.emit('foo', 'bar', room='baz'))
self.pm._publish.mock.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': 'baz', 'skip_sid': None,
- 'callback': None})
+ 'callback': None, 'host_id': '123456'})
def test_emit_with_skip_sid(self):
_run(self.pm.emit('foo', 'bar', skip_sid='baz'))
self.pm._publish.mock.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': None, 'skip_sid': 'baz',
- 'callback': None})
+ 'callback': None, 'host_id': '123456'})
def test_emit_with_callback(self):
with mock.patch.object(self.pm, '_generate_ack_id',
@@ -101,7 +101,7 @@ class TestAsyncPubSubManager(unittest.TestCase):
self.pm._publish.mock.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': 'baz', 'skip_sid': None,
- 'callback': ('baz', '/', '123')})
+ 'callback': ('baz', '/', '123'), 'host_id': '123456'})
def test_emit_with_callback_without_server(self):
standalone_pm = asyncio_pubsub_manager.AsyncPubSubManager()
@@ -173,7 +173,8 @@ class TestAsyncPubSubManager(unittest.TestCase):
new=AsyncMock()) as super_emit:
_run(self.pm._handle_emit({'event': 'foo', 'data': 'bar',
'namespace': '/baz',
- 'callback': ('sid', '/baz', 123)}))
+ 'callback': ('sid', '/baz', 123),
+ 'host_id': host_id}))
self.assertEqual(super_emit.mock.call_count, 1)
self.assertEqual(super_emit.mock.call_args[0],
(self.pm, 'foo', 'bar'))