From 00d39ca6985e94e283f5c766c18cacb760e0658d Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 19 May 2019 18:31:29 +0100 Subject: [PATCH] skip_sid parameter can also be a list (fixes #202) --- socketio/asyncio_manager.py | 4 +++- socketio/base_manager.py | 4 +++- socketio/server.py | 6 ++++-- tests/asyncio/test_asyncio_manager.py | 13 +++++++++++++ tests/common/test_base_manager.py | 14 ++++++++++++++ 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/socketio/asyncio_manager.py b/socketio/asyncio_manager.py index 01bda69..f4496ec 100644 --- a/socketio/asyncio_manager.py +++ b/socketio/asyncio_manager.py @@ -15,8 +15,10 @@ class AsyncManager(BaseManager): if namespace not in self.rooms or room not in self.rooms[namespace]: return tasks = [] + if not isinstance(skip_sid, list): + skip_sid = [skip_sid] for sid in self.get_participants(namespace, room): - if sid != skip_sid: + if sid not in skip_sid: if callback is not None: id = self._generate_ack_id(sid, namespace, callback) else: diff --git a/socketio/base_manager.py b/socketio/base_manager.py index f8aa998..3cccb85 100644 --- a/socketio/base_manager.py +++ b/socketio/base_manager.py @@ -130,8 +130,10 @@ class BaseManager(object): connected to the namespace.""" if namespace not in self.rooms or room not in self.rooms[namespace]: return + if not isinstance(skip_sid, list): + skip_sid = [skip_sid] for sid in self.get_participants(namespace, room): - if sid != skip_sid: + if sid not in skip_sid: if callback is not None: id = self._generate_ack_id(sid, namespace, callback) else: diff --git a/socketio/server.py b/socketio/server.py index 68bf73c..3dd18a3 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -218,7 +218,8 @@ class Server(object): connected clients. :param skip_sid: The session ID of a client to skip when broadcasting to a room or to all clients. This can be used to - prevent a message from being sent to the sender. + prevent a message from being sent to the sender. To + skip multiple sids, pass a list. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. @@ -258,7 +259,8 @@ class Server(object): connected clients. :param skip_sid: The session ID of a client to skip when broadcasting to a room or to all clients. This can be used to - prevent a message from being sent to the sender. + prevent a message from being sent to the sender. To + skip multiple sids, pass a list. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. diff --git a/tests/asyncio/test_asyncio_manager.py b/tests/asyncio/test_asyncio_manager.py index 2687a4b..20772fc 100644 --- a/tests/asyncio/test_asyncio_manager.py +++ b/tests/asyncio/test_asyncio_manager.py @@ -255,6 +255,19 @@ class TestAsyncManager(unittest.TestCase): self.bm.server._emit_internal.mock.assert_any_call( '789', 'my event', {'foo': 'bar'}, '/foo', None) + def test_emit_to_all_skip_two(self): + self.bm.connect('123', '/foo') + self.bm.enter_room('123', '/foo', 'bar') + self.bm.connect('456', '/foo') + self.bm.enter_room('456', '/foo', 'bar') + self.bm.connect('789', '/foo') + self.bm.connect('abc', '/bar') + _run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', + skip_sid=['123', '789'])) + self.assertEqual(self.bm.server._emit_internal.mock.call_count, 1) + self.bm.server._emit_internal.mock.assert_any_call( + '456', 'my event', {'foo': 'bar'}, '/foo', None) + def test_emit_with_callback(self): self.bm.connect('123', '/foo') self.bm._generate_ack_id = mock.MagicMock() diff --git a/tests/common/test_base_manager.py b/tests/common/test_base_manager.py index d4291f9..3c59785 100644 --- a/tests/common/test_base_manager.py +++ b/tests/common/test_base_manager.py @@ -237,6 +237,20 @@ class TestBaseManager(unittest.TestCase): {'foo': 'bar'}, '/foo', None) + def test_emit_to_all_skip_two(self): + self.bm.connect('123', '/foo') + self.bm.enter_room('123', '/foo', 'bar') + self.bm.connect('456', '/foo') + self.bm.enter_room('456', '/foo', 'bar') + self.bm.connect('789', '/foo') + self.bm.connect('abc', '/bar') + self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', + skip_sid=['123', '789']) + self.assertEqual(self.bm.server._emit_internal.call_count, 1) + self.bm.server._emit_internal.assert_any_call('456', 'my event', + {'foo': 'bar'}, '/foo', + None) + def test_emit_with_callback(self): self.bm.connect('123', '/foo') self.bm._generate_ack_id = mock.MagicMock()