Browse Source

Move ack functionality into BaseManager class

pull/5/head
Miguel Grinberg 10 years ago
parent
commit
ad12b837be
  1. 1
      setup.py
  2. 34
      socketio/base_manager.py
  3. 35
      socketio/server.py
  4. 45
      tests/test_base_manager.py
  5. 35
      tests/test_server.py

1
setup.py

@ -29,6 +29,7 @@ setup(
], ],
tests_require=[ tests_require=[
'mock', 'mock',
'pbr<1.7.0', # temporary, to workaround bug in 1.7.0
], ],
test_suite='tests', test_suite='tests',
classifiers=[ classifiers=[

34
socketio/base_manager.py

@ -1,3 +1,5 @@
import itertools
import six import six
@ -14,6 +16,7 @@ class BaseManager(object):
self.server = server self.server = server
self.rooms = {} self.rooms = {}
self.pending_removals = [] self.pending_removals = []
self.callbacks = {}
def get_namespaces(self): def get_namespaces(self):
"""Return an iterable with the active namespace names.""" """Return an iterable with the active namespace names."""
@ -43,6 +46,10 @@ class BaseManager(object):
rooms.append(room_name) rooms.append(room_name)
for room in rooms: for room in rooms:
self.leave_room(sid, namespace, room) self.leave_room(sid, namespace, room)
if sid in self.callbacks and namespace in self.callbacks[sid]:
del self.callbacks[sid][namespace]
if len(self.callbacks[sid]) == 0:
del self.callbacks[sid]
def enter_room(self, sid, namespace, room): def enter_room(self, sid, namespace, room):
"""Add a client to a room.""" """Add a client to a room."""
@ -86,8 +93,31 @@ class BaseManager(object):
return return
for sid in self.get_participants(namespace, room): for sid in self.get_participants(namespace, room):
if sid != skip_sid: if sid != skip_sid:
self.server._emit_internal(sid, event, data, namespace, if callback is not None:
callback) id = self.server._generate_ack_id(sid, namespace, callback)
else:
id = None
self.server._emit_internal(sid, event, data, namespace, id)
def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)
def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id
def _clean_rooms(self): def _clean_rooms(self):
"""Remove all the inactive room participants.""" """Remove all the inactive room participants."""

35
socketio/server.py

@ -1,4 +1,3 @@
import itertools
import logging import logging
import engineio import engineio
@ -83,7 +82,6 @@ class Server(object):
self.environ = {} self.environ = {}
self.handlers = {} self.handlers = {}
self.callbacks = {}
self._binary_packet = None self._binary_packet = None
self._attachment_count = 0 self._attachment_count = 0
@ -304,12 +302,8 @@ class Server(object):
""" """
return self.eio.handle_request(environ, start_response) return self.eio.handle_request(environ, start_response)
def _emit_internal(self, sid, event, data, namespace=None, callback=None): def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client.""" """Send a message to a client."""
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
if six.PY2 and not self.binary: if six.PY2 and not self.binary:
binary = False # pragma: nocover binary = False # pragma: nocover
else: else:
@ -353,13 +347,9 @@ class Server(object):
if n != '/' and self.manager.is_connected(sid, n): if n != '/' and self.manager.is_connected(sid, n):
self._trigger_event('disconnect', n, sid) self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n) self.manager.disconnect(sid, n)
if sid in self.callbacks and n in self.callbacks[sid]:
del self.callbacks[sid][n]
if namespace == '/' and self.manager.is_connected(sid, namespace): if namespace == '/' and self.manager.is_connected(sid, namespace):
self._trigger_event('disconnect', '/', sid) self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/') self.manager.disconnect(sid, '/')
if sid in self.callbacks:
del self.callbacks[sid]
if sid in self.environ: if sid in self.environ:
del self.environ[sid] del self.environ[sid]
@ -390,34 +380,13 @@ class Server(object):
"""Handle ACK packets from the client.""" """Handle ACK packets from the client."""
namespace = namespace or '/' namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace)
self._trigger_callback(sid, namespace, id, data) self.manager.trigger_callback(sid, namespace, id, data)
def _trigger_event(self, event, namespace, *args): def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler.""" """Invoke an application event handler."""
if namespace in self.handlers and event in self.handlers[namespace]: if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args) return self.handlers[namespace][event](*args)
def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id
def _trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
namespace = namespace or '/'
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)
def _handle_eio_connect(self, sid, environ): def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event.""" """Handle the Engine.IO connection event."""
self.environ[sid] = environ self.environ[sid] = environ

45
tests/test_base_manager.py

@ -12,7 +12,6 @@ from socketio import base_manager
class TestBaseManager(unittest.TestCase): class TestBaseManager(unittest.TestCase):
def setUp(self): def setUp(self):
mock_server = mock.MagicMock() mock_server = mock.MagicMock()
mock_server.rooms = {}
self.bm = base_manager.BaseManager(mock_server) self.bm = base_manager.BaseManager(mock_server)
def test_connect(self): def test_connect(self):
@ -78,6 +77,40 @@ class TestBaseManager(unittest.TestCase):
self.bm._clean_rooms() self.bm._clean_rooms()
self.assertEqual(self.bm.rooms, {}) self.assertEqual(self.bm.rooms, {})
def test_disconnect_with_callbacks(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
self.bm._generate_ack_id('123', '/', 'f')
self.bm._generate_ack_id('123', '/foo', 'g')
self.bm.disconnect('123', '/foo')
self.assertNotIn('/foo', self.bm.callbacks['123'])
self.bm.disconnect('123', '/')
self.assertNotIn('123', self.bm.callbacks)
def test_trigger_callback(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
cb = mock.MagicMock()
id1 = self.bm._generate_ack_id('123', '/', cb)
id2 = self.bm._generate_ack_id('123', '/foo', cb)
self.bm.trigger_callback('123', '/', id1, ['foo'])
self.bm.trigger_callback('123', '/foo', id2, ['bar', 'baz'])
self.assertEqual(cb.call_count, 2)
cb.assert_any_call('foo')
cb.assert_any_call('bar', 'baz')
def test_invalid_callback(self):
self.bm.connect('123', '/')
cb = mock.MagicMock()
id = self.bm._generate_ack_id('123', '/', cb)
self.assertRaises(ValueError, self.bm.trigger_callback,
'124', '/', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/foo', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/', id + 1, ['foo'])
self.assertEqual(cb.call_count, 0)
def test_get_namespaces(self): def test_get_namespaces(self):
self.assertEqual(list(self.bm.get_namespaces()), []) self.assertEqual(list(self.bm.get_namespaces()), [])
self.bm.connect('123', '/') self.bm.connect('123', '/')
@ -185,6 +218,16 @@ class TestBaseManager(unittest.TestCase):
{'foo': 'bar'}, '/foo', {'foo': 'bar'}, '/foo',
None) None)
def test_emit_with_callback(self):
self.bm.connect('123', '/foo')
self.bm.server._generate_ack_id.return_value = 11
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo',
callback='cb')
self.bm.server._emit_internal.assert_called_once_with('123',
'my event',
{'foo': 'bar'},
'/foo', 11)
def test_emit_to_invalid_room(self): def test_emit_to_invalid_room(self):
self.bm.emit('my event', {'foo': 'bar'}, namespace='/', room='123') self.bm.emit('my event', {'foo': 'bar'}, namespace='/', room='123')

35
tests/test_server.py

@ -127,8 +127,8 @@ class TestServer(unittest.TestCase):
def test_emit_internal_with_callback(self, eio): def test_emit_internal_with_callback(self, eio):
s = server.Server() s = server.Server()
s._emit_internal('123', 'my event', 'my data', namespace='/foo', id = s.manager._generate_ack_id('123', '/foo', 'cb')
callback='cb') s._emit_internal('123', 'my event', 'my data', namespace='/foo', id=id)
s.eio.send.assert_called_once_with('123', s.eio.send.assert_called_once_with('123',
'2/foo,1["my event","my data"]', '2/foo,1["my event","my data"]',
binary=False) binary=False)
@ -323,40 +323,25 @@ class TestServer(unittest.TestCase):
def test_send_with_ack(self, eio): def test_send_with_ack(self, eio):
s = server.Server() s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ') s._handle_eio_connect('123', 'environ')
s._emit_internal('123', 'my event', ['foo'], callback=cb) cb = mock.MagicMock()
s._emit_internal('123', 'my event', ['bar'], callback=cb) id1 = s.manager._generate_ack_id('123', '/', cb)
id2 = s.manager._generate_ack_id('123', '/', cb)
s._emit_internal('123', 'my event', ['foo'], id=id1)
s._emit_internal('123', 'my event', ['bar'], id=id2)
s._handle_eio_message('123', '31["foo",2]') s._handle_eio_message('123', '31["foo",2]')
cb.assert_called_once_with('foo', 2) cb.assert_called_once_with('foo', 2)
self.assertIn('123', s.callbacks)
s._handle_disconnect('123', '/')
self.assertNotIn('123', s.callbacks)
def test_send_with_ack_namespace(self, eio): def test_send_with_ack_namespace(self, eio):
s = server.Server() s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ') s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0/foo') s._handle_eio_message('123', '0/foo')
cb = mock.MagicMock()
id = s.manager._generate_ack_id('123', '/foo', cb)
s._emit_internal('123', 'my event', ['foo'], namespace='/foo', s._emit_internal('123', 'my event', ['foo'], namespace='/foo',
callback=cb) id=id)
s._handle_eio_message('123', '3/foo,1["foo",2]') s._handle_eio_message('123', '3/foo,1["foo",2]')
cb.assert_called_once_with('foo', 2) cb.assert_called_once_with('foo', 2)
self.assertIn('/foo', s.callbacks['123'])
s._handle_eio_disconnect('123')
self.assertNotIn('123', s.callbacks)
def test_invalid_callback(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._emit_internal('123', 'my event', ['foo'], callback=cb)
self.assertRaises(ValueError, s._handle_eio_message, '124',
'31["foo",2]')
self.assertRaises(ValueError, s._handle_eio_message, '123',
'3/foo,1["foo",2]')
self.assertRaises(ValueError, s._handle_eio_message, '123',
'32["foo",2]')
def test_disconnect(self, eio): def test_disconnect(self, eio):
s = server.Server() s = server.Server()

Loading…
Cancel
Save