diff --git a/socketio/asyncio_manager.py b/socketio/asyncio_manager.py index 5322aa9..08b02d7 100644 --- a/socketio/asyncio_manager.py +++ b/socketio/asyncio_manager.py @@ -48,10 +48,9 @@ class AsyncManager(BaseManager): else: del self.callbacks[sid][namespace][id] if callback is not None: - if asyncio.iscoroutinefunction(callback) is True: + ret = callback(*data) + if asyncio.iscoroutine(ret): try: - await callback(*data) + await ret except asyncio.CancelledError: # pragma: no cover pass - else: - callback(*data) diff --git a/socketio/pubsub_manager.py b/socketio/pubsub_manager.py index 4810728..afbe276 100644 --- a/socketio/pubsub_manager.py +++ b/socketio/pubsub_manager.py @@ -63,7 +63,8 @@ class PubSubManager(BaseManager): callback = None self._publish({'method': 'emit', 'event': event, 'data': data, 'namespace': namespace, 'room': room, - 'skip_sid': skip_sid, 'callback': callback}) + 'skip_sid': skip_sid, 'callback': callback, + 'host_id': self.host_id}) def close_room(self, room, namespace=None): self._publish({'method': 'close_room', 'room': room, @@ -93,8 +94,9 @@ class PubSubManager(BaseManager): # Here in the receiving end we set up a local callback that preserves # the callback host and id from the sender remote_callback = message.get('callback') + remote_host_id = message.get('host_id') if remote_callback is not None and len(remote_callback) == 3: - callback = partial(self._return_callback, self.host_id, + callback = partial(self._return_callback, remote_host_id, *remote_callback) else: callback = None diff --git a/tests/test_pubsub_manager.py b/tests/test_pubsub_manager.py index 684dedb..0430461 100644 --- a/tests/test_pubsub_manager.py +++ b/tests/test_pubsub_manager.py @@ -17,11 +17,11 @@ class TestBaseManager(unittest.TestCase): self.pm = pubsub_manager.PubSubManager() self.pm._publish = mock.MagicMock() self.pm.set_server(mock_server) + self.pm.host_id = '123456' self.pm.initialize() def test_default_init(self): self.assertEqual(self.pm.channel, 'socketio') - self.assertEqual(len(self.pm.host_id), 32) self.pm.server.start_background_task.assert_called_once_with( self.pm._thread) @@ -44,28 +44,28 @@ class TestBaseManager(unittest.TestCase): self.pm._publish.assert_called_once_with( {'method': 'emit', 'event': 'foo', 'data': 'bar', 'namespace': '/', 'room': None, 'skip_sid': None, - 'callback': None}) + 'callback': None, 'host_id': '123456'}) def test_emit_with_namespace(self): self.pm.emit('foo', 'bar', namespace='/baz') self.pm._publish.assert_called_once_with( {'method': 'emit', 'event': 'foo', 'data': 'bar', 'namespace': '/baz', 'room': None, 'skip_sid': None, - 'callback': None}) + 'callback': None, 'host_id': '123456'}) def test_emit_with_room(self): self.pm.emit('foo', 'bar', room='baz') self.pm._publish.assert_called_once_with( {'method': 'emit', 'event': 'foo', 'data': 'bar', 'namespace': '/', 'room': 'baz', 'skip_sid': None, - 'callback': None}) + 'callback': None, 'host_id': '123456'}) def test_emit_with_skip_sid(self): self.pm.emit('foo', 'bar', skip_sid='baz') self.pm._publish.assert_called_once_with( {'method': 'emit', 'event': 'foo', 'data': 'bar', 'namespace': '/', 'room': None, 'skip_sid': 'baz', - 'callback': None}) + 'callback': None, 'host_id': '123456'}) def test_emit_with_callback(self): with mock.patch.object(self.pm, '_generate_ack_id', @@ -74,7 +74,7 @@ class TestBaseManager(unittest.TestCase): self.pm._publish.assert_called_once_with( {'method': 'emit', 'event': 'foo', 'data': 'bar', 'namespace': '/', 'room': 'baz', 'skip_sid': None, - 'callback': ('baz', '/', '123')}) + 'callback': ('baz', '/', '123'), 'host_id': '123456'}) def test_emit_with_callback_without_server(self): standalone_pm = pubsub_manager.PubSubManager() @@ -141,7 +141,8 @@ class TestBaseManager(unittest.TestCase): with mock.patch.object(base_manager.BaseManager, 'emit') as super_emit: self.pm._handle_emit({'event': 'foo', 'data': 'bar', 'namespace': '/baz', - 'callback': ('sid', '/baz', 123)}) + 'callback': ('sid', '/baz', 123), + 'host_id': host_id}) self.assertEqual(super_emit.call_count, 1) self.assertEqual(super_emit.call_args[0], ('foo', 'bar')) self.assertEqual(super_emit.call_args[1]['namespace'], '/baz')