Browse Source

Move ack functionality into BaseManager class

pull/5/head
Miguel Grinberg 10 years ago
parent
commit
ad12b837be
  1. 1
      setup.py
  2. 34
      socketio/base_manager.py
  3. 35
      socketio/server.py
  4. 45
      tests/test_base_manager.py
  5. 35
      tests/test_server.py

1
setup.py

@ -29,6 +29,7 @@ setup(
],
tests_require=[
'mock',
'pbr<1.7.0', # temporary, to workaround bug in 1.7.0
],
test_suite='tests',
classifiers=[

34
socketio/base_manager.py

@ -1,3 +1,5 @@
import itertools
import six
@ -14,6 +16,7 @@ class BaseManager(object):
self.server = server
self.rooms = {}
self.pending_removals = []
self.callbacks = {}
def get_namespaces(self):
"""Return an iterable with the active namespace names."""
@ -43,6 +46,10 @@ class BaseManager(object):
rooms.append(room_name)
for room in rooms:
self.leave_room(sid, namespace, room)
if sid in self.callbacks and namespace in self.callbacks[sid]:
del self.callbacks[sid][namespace]
if len(self.callbacks[sid]) == 0:
del self.callbacks[sid]
def enter_room(self, sid, namespace, room):
"""Add a client to a room."""
@ -86,8 +93,31 @@ class BaseManager(object):
return
for sid in self.get_participants(namespace, room):
if sid != skip_sid:
self.server._emit_internal(sid, event, data, namespace,
callback)
if callback is not None:
id = self.server._generate_ack_id(sid, namespace, callback)
else:
id = None
self.server._emit_internal(sid, event, data, namespace, id)
def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)
def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id
def _clean_rooms(self):
"""Remove all the inactive room participants."""

35
socketio/server.py

@ -1,4 +1,3 @@
import itertools
import logging
import engineio
@ -83,7 +82,6 @@ class Server(object):
self.environ = {}
self.handlers = {}
self.callbacks = {}
self._binary_packet = None
self._attachment_count = 0
@ -304,12 +302,8 @@ class Server(object):
"""
return self.eio.handle_request(environ, start_response)
def _emit_internal(self, sid, event, data, namespace=None, callback=None):
def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client."""
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
@ -353,13 +347,9 @@ class Server(object):
if n != '/' and self.manager.is_connected(sid, n):
self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if sid in self.callbacks and n in self.callbacks[sid]:
del self.callbacks[sid][n]
if namespace == '/' and self.manager.is_connected(sid, namespace):
self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
if sid in self.callbacks:
del self.callbacks[sid]
if sid in self.environ:
del self.environ[sid]
@ -390,34 +380,13 @@ class Server(object):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace)
self._trigger_callback(sid, namespace, id, data)
self.manager.trigger_callback(sid, namespace, id, data)
def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args)
def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id
def _trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
namespace = namespace or '/'
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)
def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event."""
self.environ[sid] = environ

45
tests/test_base_manager.py

@ -12,7 +12,6 @@ from socketio import base_manager
class TestBaseManager(unittest.TestCase):
def setUp(self):
mock_server = mock.MagicMock()
mock_server.rooms = {}
self.bm = base_manager.BaseManager(mock_server)
def test_connect(self):
@ -78,6 +77,40 @@ class TestBaseManager(unittest.TestCase):
self.bm._clean_rooms()
self.assertEqual(self.bm.rooms, {})
def test_disconnect_with_callbacks(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
self.bm._generate_ack_id('123', '/', 'f')
self.bm._generate_ack_id('123', '/foo', 'g')
self.bm.disconnect('123', '/foo')
self.assertNotIn('/foo', self.bm.callbacks['123'])
self.bm.disconnect('123', '/')
self.assertNotIn('123', self.bm.callbacks)
def test_trigger_callback(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
cb = mock.MagicMock()
id1 = self.bm._generate_ack_id('123', '/', cb)
id2 = self.bm._generate_ack_id('123', '/foo', cb)
self.bm.trigger_callback('123', '/', id1, ['foo'])
self.bm.trigger_callback('123', '/foo', id2, ['bar', 'baz'])
self.assertEqual(cb.call_count, 2)
cb.assert_any_call('foo')
cb.assert_any_call('bar', 'baz')
def test_invalid_callback(self):
self.bm.connect('123', '/')
cb = mock.MagicMock()
id = self.bm._generate_ack_id('123', '/', cb)
self.assertRaises(ValueError, self.bm.trigger_callback,
'124', '/', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/foo', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/', id + 1, ['foo'])
self.assertEqual(cb.call_count, 0)
def test_get_namespaces(self):
self.assertEqual(list(self.bm.get_namespaces()), [])
self.bm.connect('123', '/')
@ -185,6 +218,16 @@ class TestBaseManager(unittest.TestCase):
{'foo': 'bar'}, '/foo',
None)
def test_emit_with_callback(self):
self.bm.connect('123', '/foo')
self.bm.server._generate_ack_id.return_value = 11
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo',
callback='cb')
self.bm.server._emit_internal.assert_called_once_with('123',
'my event',
{'foo': 'bar'},
'/foo', 11)
def test_emit_to_invalid_room(self):
self.bm.emit('my event', {'foo': 'bar'}, namespace='/', room='123')

35
tests/test_server.py

@ -127,8 +127,8 @@ class TestServer(unittest.TestCase):
def test_emit_internal_with_callback(self, eio):
s = server.Server()
s._emit_internal('123', 'my event', 'my data', namespace='/foo',
callback='cb')
id = s.manager._generate_ack_id('123', '/foo', 'cb')
s._emit_internal('123', 'my event', 'my data', namespace='/foo', id=id)
s.eio.send.assert_called_once_with('123',
'2/foo,1["my event","my data"]',
binary=False)
@ -323,40 +323,25 @@ class TestServer(unittest.TestCase):
def test_send_with_ack(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._emit_internal('123', 'my event', ['foo'], callback=cb)
s._emit_internal('123', 'my event', ['bar'], callback=cb)
cb = mock.MagicMock()
id1 = s.manager._generate_ack_id('123', '/', cb)
id2 = s.manager._generate_ack_id('123', '/', cb)
s._emit_internal('123', 'my event', ['foo'], id=id1)
s._emit_internal('123', 'my event', ['bar'], id=id2)
s._handle_eio_message('123', '31["foo",2]')
cb.assert_called_once_with('foo', 2)
self.assertIn('123', s.callbacks)
s._handle_disconnect('123', '/')
self.assertNotIn('123', s.callbacks)
def test_send_with_ack_namespace(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0/foo')
cb = mock.MagicMock()
id = s.manager._generate_ack_id('123', '/foo', cb)
s._emit_internal('123', 'my event', ['foo'], namespace='/foo',
callback=cb)
id=id)
s._handle_eio_message('123', '3/foo,1["foo",2]')
cb.assert_called_once_with('foo', 2)
self.assertIn('/foo', s.callbacks['123'])
s._handle_eio_disconnect('123')
self.assertNotIn('123', s.callbacks)
def test_invalid_callback(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._emit_internal('123', 'my event', ['foo'], callback=cb)
self.assertRaises(ValueError, s._handle_eio_message, '124',
'31["foo",2]')
self.assertRaises(ValueError, s._handle_eio_message, '123',
'3/foo,1["foo",2]')
self.assertRaises(ValueError, s._handle_eio_message, '123',
'32["foo",2]')
def test_disconnect(self, eio):
s = server.Server()

Loading…
Cancel
Save