diff --git a/setup.py b/setup.py index bd97e11..96be8f6 100755 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ setup( ], tests_require=[ 'mock', + 'pbr<1.7.0', # temporary, to workaround bug in 1.7.0 ], test_suite='tests', classifiers=[ diff --git a/socketio/base_manager.py b/socketio/base_manager.py index fa83d5b..c8e6ff3 100644 --- a/socketio/base_manager.py +++ b/socketio/base_manager.py @@ -1,3 +1,5 @@ +import itertools + import six @@ -14,6 +16,7 @@ class BaseManager(object): self.server = server self.rooms = {} self.pending_removals = [] + self.callbacks = {} def get_namespaces(self): """Return an iterable with the active namespace names.""" @@ -43,6 +46,10 @@ class BaseManager(object): rooms.append(room_name) for room in rooms: self.leave_room(sid, namespace, room) + if sid in self.callbacks and namespace in self.callbacks[sid]: + del self.callbacks[sid][namespace] + if len(self.callbacks[sid]) == 0: + del self.callbacks[sid] def enter_room(self, sid, namespace, room): """Add a client to a room.""" @@ -86,8 +93,31 @@ class BaseManager(object): return for sid in self.get_participants(namespace, room): if sid != skip_sid: - self.server._emit_internal(sid, event, data, namespace, - callback) + if callback is not None: + id = self.server._generate_ack_id(sid, namespace, callback) + else: + id = None + self.server._emit_internal(sid, event, data, namespace, id) + + def trigger_callback(self, sid, namespace, id, data): + """Invoke an application callback.""" + try: + callback = self.callbacks[sid][namespace][id] + except KeyError: + raise ValueError('Unknown callback') + del self.callbacks[sid][namespace][id] + callback(*data) + + def _generate_ack_id(self, sid, namespace, callback): + """Generate a unique identifier for an ACK packet.""" + namespace = namespace or '/' + if sid not in self.callbacks: + self.callbacks[sid] = {} + if namespace not in self.callbacks[sid]: + self.callbacks[sid][namespace] = {0: itertools.count(1)} + id = six.next(self.callbacks[sid][namespace][0]) + self.callbacks[sid][namespace][id] = callback + return id def _clean_rooms(self): """Remove all the inactive room participants.""" diff --git a/socketio/server.py b/socketio/server.py old mode 100755 new mode 100644 index a9cdbb5..0a4dd94 --- a/socketio/server.py +++ b/socketio/server.py @@ -1,4 +1,3 @@ -import itertools import logging import engineio @@ -83,7 +82,6 @@ class Server(object): self.environ = {} self.handlers = {} - self.callbacks = {} self._binary_packet = None self._attachment_count = 0 @@ -304,12 +302,8 @@ class Server(object): """ return self.eio.handle_request(environ, start_response) - def _emit_internal(self, sid, event, data, namespace=None, callback=None): + def _emit_internal(self, sid, event, data, namespace=None, id=None): """Send a message to a client.""" - if callback is not None: - id = self._generate_ack_id(sid, namespace, callback) - else: - id = None if six.PY2 and not self.binary: binary = False # pragma: nocover else: @@ -353,13 +347,9 @@ class Server(object): if n != '/' and self.manager.is_connected(sid, n): self._trigger_event('disconnect', n, sid) self.manager.disconnect(sid, n) - if sid in self.callbacks and n in self.callbacks[sid]: - del self.callbacks[sid][n] if namespace == '/' and self.manager.is_connected(sid, namespace): self._trigger_event('disconnect', '/', sid) self.manager.disconnect(sid, '/') - if sid in self.callbacks: - del self.callbacks[sid] if sid in self.environ: del self.environ[sid] @@ -390,34 +380,13 @@ class Server(object): """Handle ACK packets from the client.""" namespace = namespace or '/' self.logger.info('received ack from %s [%s]', sid, namespace) - self._trigger_callback(sid, namespace, id, data) + self.manager.trigger_callback(sid, namespace, id, data) def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" if namespace in self.handlers and event in self.handlers[namespace]: return self.handlers[namespace][event](*args) - def _generate_ack_id(self, sid, namespace, callback): - """Generate a unique identifier for an ACK packet.""" - namespace = namespace or '/' - if sid not in self.callbacks: - self.callbacks[sid] = {} - if namespace not in self.callbacks[sid]: - self.callbacks[sid][namespace] = {0: itertools.count(1)} - id = six.next(self.callbacks[sid][namespace][0]) - self.callbacks[sid][namespace][id] = callback - return id - - def _trigger_callback(self, sid, namespace, id, data): - """Invoke an application callback.""" - namespace = namespace or '/' - try: - callback = self.callbacks[sid][namespace][id] - except KeyError: - raise ValueError('Unknown callback') - del self.callbacks[sid][namespace][id] - callback(*data) - def _handle_eio_connect(self, sid, environ): """Handle the Engine.IO connection event.""" self.environ[sid] = environ diff --git a/tests/test_base_manager.py b/tests/test_base_manager.py index ffb1f43..9790777 100644 --- a/tests/test_base_manager.py +++ b/tests/test_base_manager.py @@ -12,7 +12,6 @@ from socketio import base_manager class TestBaseManager(unittest.TestCase): def setUp(self): mock_server = mock.MagicMock() - mock_server.rooms = {} self.bm = base_manager.BaseManager(mock_server) def test_connect(self): @@ -78,6 +77,40 @@ class TestBaseManager(unittest.TestCase): self.bm._clean_rooms() self.assertEqual(self.bm.rooms, {}) + def test_disconnect_with_callbacks(self): + self.bm.connect('123', '/') + self.bm.connect('123', '/foo') + self.bm._generate_ack_id('123', '/', 'f') + self.bm._generate_ack_id('123', '/foo', 'g') + self.bm.disconnect('123', '/foo') + self.assertNotIn('/foo', self.bm.callbacks['123']) + self.bm.disconnect('123', '/') + self.assertNotIn('123', self.bm.callbacks) + + def test_trigger_callback(self): + self.bm.connect('123', '/') + self.bm.connect('123', '/foo') + cb = mock.MagicMock() + id1 = self.bm._generate_ack_id('123', '/', cb) + id2 = self.bm._generate_ack_id('123', '/foo', cb) + self.bm.trigger_callback('123', '/', id1, ['foo']) + self.bm.trigger_callback('123', '/foo', id2, ['bar', 'baz']) + self.assertEqual(cb.call_count, 2) + cb.assert_any_call('foo') + cb.assert_any_call('bar', 'baz') + + def test_invalid_callback(self): + self.bm.connect('123', '/') + cb = mock.MagicMock() + id = self.bm._generate_ack_id('123', '/', cb) + self.assertRaises(ValueError, self.bm.trigger_callback, + '124', '/', id, ['foo']) + self.assertRaises(ValueError, self.bm.trigger_callback, + '123', '/foo', id, ['foo']) + self.assertRaises(ValueError, self.bm.trigger_callback, + '123', '/', id + 1, ['foo']) + self.assertEqual(cb.call_count, 0) + def test_get_namespaces(self): self.assertEqual(list(self.bm.get_namespaces()), []) self.bm.connect('123', '/') @@ -185,6 +218,16 @@ class TestBaseManager(unittest.TestCase): {'foo': 'bar'}, '/foo', None) + def test_emit_with_callback(self): + self.bm.connect('123', '/foo') + self.bm.server._generate_ack_id.return_value = 11 + self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', + callback='cb') + self.bm.server._emit_internal.assert_called_once_with('123', + 'my event', + {'foo': 'bar'}, + '/foo', 11) + def test_emit_to_invalid_room(self): self.bm.emit('my event', {'foo': 'bar'}, namespace='/', room='123') diff --git a/tests/test_server.py b/tests/test_server.py index 8e3a868..dad08b7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -127,8 +127,8 @@ class TestServer(unittest.TestCase): def test_emit_internal_with_callback(self, eio): s = server.Server() - s._emit_internal('123', 'my event', 'my data', namespace='/foo', - callback='cb') + id = s.manager._generate_ack_id('123', '/foo', 'cb') + s._emit_internal('123', 'my event', 'my data', namespace='/foo', id=id) s.eio.send.assert_called_once_with('123', '2/foo,1["my event","my data"]', binary=False) @@ -323,40 +323,25 @@ class TestServer(unittest.TestCase): def test_send_with_ack(self, eio): s = server.Server() - cb = mock.MagicMock() s._handle_eio_connect('123', 'environ') - s._emit_internal('123', 'my event', ['foo'], callback=cb) - s._emit_internal('123', 'my event', ['bar'], callback=cb) + cb = mock.MagicMock() + id1 = s.manager._generate_ack_id('123', '/', cb) + id2 = s.manager._generate_ack_id('123', '/', cb) + s._emit_internal('123', 'my event', ['foo'], id=id1) + s._emit_internal('123', 'my event', ['bar'], id=id2) s._handle_eio_message('123', '31["foo",2]') cb.assert_called_once_with('foo', 2) - self.assertIn('123', s.callbacks) - s._handle_disconnect('123', '/') - self.assertNotIn('123', s.callbacks) def test_send_with_ack_namespace(self, eio): s = server.Server() - cb = mock.MagicMock() s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo') + cb = mock.MagicMock() + id = s.manager._generate_ack_id('123', '/foo', cb) s._emit_internal('123', 'my event', ['foo'], namespace='/foo', - callback=cb) + id=id) s._handle_eio_message('123', '3/foo,1["foo",2]') cb.assert_called_once_with('foo', 2) - self.assertIn('/foo', s.callbacks['123']) - s._handle_eio_disconnect('123') - self.assertNotIn('123', s.callbacks) - - def test_invalid_callback(self, eio): - s = server.Server() - cb = mock.MagicMock() - s._handle_eio_connect('123', 'environ') - s._emit_internal('123', 'my event', ['foo'], callback=cb) - self.assertRaises(ValueError, s._handle_eio_message, '124', - '31["foo",2]') - self.assertRaises(ValueError, s._handle_eio_message, '123', - '3/foo,1["foo",2]') - self.assertRaises(ValueError, s._handle_eio_message, '123', - '32["foo",2]') def test_disconnect(self, eio): s = server.Server()