Browse Source

Properly handle callbacks in multi-host configurations

Fixes #150
pull/152/head
Miguel Grinberg 7 years ago
parent
commit
8d7059a1a2
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 7
      socketio/asyncio_manager.py
  2. 6
      socketio/pubsub_manager.py
  3. 15
      tests/test_pubsub_manager.py

7
socketio/asyncio_manager.py

@ -48,10 +48,9 @@ class AsyncManager(BaseManager):
else: else:
del self.callbacks[sid][namespace][id] del self.callbacks[sid][namespace][id]
if callback is not None: if callback is not None:
if asyncio.iscoroutinefunction(callback) is True: ret = callback(*data)
if asyncio.iscoroutine(ret):
try: try:
await callback(*data) await ret
except asyncio.CancelledError: # pragma: no cover except asyncio.CancelledError: # pragma: no cover
pass pass
else:
callback(*data)

6
socketio/pubsub_manager.py

@ -63,7 +63,8 @@ class PubSubManager(BaseManager):
callback = None callback = None
self._publish({'method': 'emit', 'event': event, 'data': data, self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace, 'room': room, 'namespace': namespace, 'room': room,
'skip_sid': skip_sid, 'callback': callback}) 'skip_sid': skip_sid, 'callback': callback,
'host_id': self.host_id})
def close_room(self, room, namespace=None): def close_room(self, room, namespace=None):
self._publish({'method': 'close_room', 'room': room, self._publish({'method': 'close_room', 'room': room,
@ -93,8 +94,9 @@ class PubSubManager(BaseManager):
# Here in the receiving end we set up a local callback that preserves # Here in the receiving end we set up a local callback that preserves
# the callback host and id from the sender # the callback host and id from the sender
remote_callback = message.get('callback') remote_callback = message.get('callback')
remote_host_id = message.get('host_id')
if remote_callback is not None and len(remote_callback) == 3: if remote_callback is not None and len(remote_callback) == 3:
callback = partial(self._return_callback, self.host_id, callback = partial(self._return_callback, remote_host_id,
*remote_callback) *remote_callback)
else: else:
callback = None callback = None

15
tests/test_pubsub_manager.py

@ -17,11 +17,11 @@ class TestBaseManager(unittest.TestCase):
self.pm = pubsub_manager.PubSubManager() self.pm = pubsub_manager.PubSubManager()
self.pm._publish = mock.MagicMock() self.pm._publish = mock.MagicMock()
self.pm.set_server(mock_server) self.pm.set_server(mock_server)
self.pm.host_id = '123456'
self.pm.initialize() self.pm.initialize()
def test_default_init(self): def test_default_init(self):
self.assertEqual(self.pm.channel, 'socketio') self.assertEqual(self.pm.channel, 'socketio')
self.assertEqual(len(self.pm.host_id), 32)
self.pm.server.start_background_task.assert_called_once_with( self.pm.server.start_background_task.assert_called_once_with(
self.pm._thread) self.pm._thread)
@ -44,28 +44,28 @@ class TestBaseManager(unittest.TestCase):
self.pm._publish.assert_called_once_with( self.pm._publish.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar', {'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': None, 'skip_sid': None, 'namespace': '/', 'room': None, 'skip_sid': None,
'callback': None}) 'callback': None, 'host_id': '123456'})
def test_emit_with_namespace(self): def test_emit_with_namespace(self):
self.pm.emit('foo', 'bar', namespace='/baz') self.pm.emit('foo', 'bar', namespace='/baz')
self.pm._publish.assert_called_once_with( self.pm._publish.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar', {'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/baz', 'room': None, 'skip_sid': None, 'namespace': '/baz', 'room': None, 'skip_sid': None,
'callback': None}) 'callback': None, 'host_id': '123456'})
def test_emit_with_room(self): def test_emit_with_room(self):
self.pm.emit('foo', 'bar', room='baz') self.pm.emit('foo', 'bar', room='baz')
self.pm._publish.assert_called_once_with( self.pm._publish.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar', {'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': 'baz', 'skip_sid': None, 'namespace': '/', 'room': 'baz', 'skip_sid': None,
'callback': None}) 'callback': None, 'host_id': '123456'})
def test_emit_with_skip_sid(self): def test_emit_with_skip_sid(self):
self.pm.emit('foo', 'bar', skip_sid='baz') self.pm.emit('foo', 'bar', skip_sid='baz')
self.pm._publish.assert_called_once_with( self.pm._publish.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar', {'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': None, 'skip_sid': 'baz', 'namespace': '/', 'room': None, 'skip_sid': 'baz',
'callback': None}) 'callback': None, 'host_id': '123456'})
def test_emit_with_callback(self): def test_emit_with_callback(self):
with mock.patch.object(self.pm, '_generate_ack_id', with mock.patch.object(self.pm, '_generate_ack_id',
@ -74,7 +74,7 @@ class TestBaseManager(unittest.TestCase):
self.pm._publish.assert_called_once_with( self.pm._publish.assert_called_once_with(
{'method': 'emit', 'event': 'foo', 'data': 'bar', {'method': 'emit', 'event': 'foo', 'data': 'bar',
'namespace': '/', 'room': 'baz', 'skip_sid': None, 'namespace': '/', 'room': 'baz', 'skip_sid': None,
'callback': ('baz', '/', '123')}) 'callback': ('baz', '/', '123'), 'host_id': '123456'})
def test_emit_with_callback_without_server(self): def test_emit_with_callback_without_server(self):
standalone_pm = pubsub_manager.PubSubManager() standalone_pm = pubsub_manager.PubSubManager()
@ -141,7 +141,8 @@ class TestBaseManager(unittest.TestCase):
with mock.patch.object(base_manager.BaseManager, 'emit') as super_emit: with mock.patch.object(base_manager.BaseManager, 'emit') as super_emit:
self.pm._handle_emit({'event': 'foo', 'data': 'bar', self.pm._handle_emit({'event': 'foo', 'data': 'bar',
'namespace': '/baz', 'namespace': '/baz',
'callback': ('sid', '/baz', 123)}) 'callback': ('sid', '/baz', 123),
'host_id': host_id})
self.assertEqual(super_emit.call_count, 1) self.assertEqual(super_emit.call_count, 1)
self.assertEqual(super_emit.call_args[0], ('foo', 'bar')) self.assertEqual(super_emit.call_args[0], ('foo', 'bar'))
self.assertEqual(super_emit.call_args[1]['namespace'], '/baz') self.assertEqual(super_emit.call_args[1]['namespace'], '/baz')

Loading…
Cancel
Save