diff --git a/src/socketio/asyncio_manager.py b/src/socketio/asyncio_manager.py index 4a014e7..20a16c8 100644 --- a/src/socketio/asyncio_manager.py +++ b/src/socketio/asyncio_manager.py @@ -33,6 +33,13 @@ class AsyncManager(BaseManager): return await asyncio.wait(tasks) + async def disconnect(self, sid, namespace, **kwargs): + """Disconnect a client. + + Note: this method is a coroutine. + """ + return super().disconnect(sid, namespace, **kwargs) + async def close_room(self, room, namespace): """Remove all participants from a room. diff --git a/src/socketio/asyncio_pubsub_manager.py b/src/socketio/asyncio_pubsub_manager.py index ac261a1..1a06889 100644 --- a/src/socketio/asyncio_pubsub_manager.py +++ b/src/socketio/asyncio_pubsub_manager.py @@ -76,7 +76,14 @@ class AsyncPubSubManager(AsyncManager): else: # client is in another server, so we post request to the queue await self._publish({'method': 'disconnect', 'sid': sid, - 'namespace': namespace or '/'}) + 'namespace': namespace or '/'}) + + async def disconnect(self, sid, namespace, **kwargs): + if kwargs.get('ignore_queue'): + return await super(AsyncPubSubManager, self).disconnect( + sid, namespace=namespace) + await self._publish({'method': 'disconnect', 'sid': sid, + 'namespace': namespace or '/'}) async def close_room(self, room, namespace=None): await self._publish({'method': 'close_room', 'room': room, diff --git a/src/socketio/asyncio_server.py b/src/socketio/asyncio_server.py index ccd55fb..eb708b2 100644 --- a/src/socketio/asyncio_server.py +++ b/src/socketio/asyncio_server.py @@ -384,7 +384,8 @@ class AsyncServer(server.Server): await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid) - self.manager.disconnect(sid, namespace=namespace) + await self.manager.disconnect(sid, namespace=namespace, + ignore_queue=True) async def handle_request(self, *args, **kwargs): """Handle an HTTP request from the client. @@ -486,7 +487,7 @@ class AsyncServer(server.Server): await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) - self.manager.disconnect(sid, namespace) + await self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) @@ -499,7 +500,7 @@ class AsyncServer(server.Server): return self.manager.pre_disconnect(sid, namespace=namespace) await self._trigger_event('disconnect', namespace, sid) - self.manager.disconnect(sid, namespace) + await self.manager.disconnect(sid, namespace, ignore_queue=True) async def _handle_event(self, eio_sid, namespace, id, data): """Handle an incoming client event.""" diff --git a/src/socketio/base_manager.py b/src/socketio/base_manager.py index 45eb85b..87d2387 100644 --- a/src/socketio/base_manager.py +++ b/src/socketio/base_manager.py @@ -68,6 +68,7 @@ class BaseManager(object): return self.rooms[namespace][None][sid] is not None except KeyError: pass + return False def sid_from_eio_sid(self, eio_sid, namespace): try: diff --git a/tests/asyncio/test_asyncio_manager.py b/tests/asyncio/test_asyncio_manager.py index d51ba47..32836bf 100644 --- a/tests/asyncio/test_asyncio_manager.py +++ b/tests/asyncio/test_asyncio_manager.py @@ -59,9 +59,9 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.pre_disconnect(sid2, '/foo') == '456' assert self.bm.pending_disconnect == {'/foo': [sid1, sid2]} assert not self.bm.is_connected(sid2, '/foo') - self.bm.disconnect(sid1, '/foo') + _run(self.bm.disconnect(sid1, '/foo')) assert self.bm.pending_disconnect == {'/foo': [sid2]} - self.bm.disconnect(sid2, '/foo') + _run(self.bm.disconnect(sid2, '/foo')) assert self.bm.pending_disconnect == {} def test_disconnect(self): @@ -69,7 +69,7 @@ class TestAsyncManager(unittest.TestCase): sid2 = self.bm.connect('456', '/foo') self.bm.enter_room(sid1, '/foo', 'bar') self.bm.enter_room(sid2, '/foo', 'baz') - self.bm.disconnect(sid1, '/foo') + _run(self.bm.disconnect(sid1, '/foo')) assert dict(self.bm.rooms['/foo'][None]) == {sid2: '456'} assert dict(self.bm.rooms['/foo'][sid2]) == {sid2: '456'} assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'} @@ -83,10 +83,10 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.is_connected(sid2, '/foo') assert not self.bm.is_connected(sid2, '/') assert not self.bm.is_connected(sid1, '/foo') - self.bm.disconnect(sid1, '/') + _run(self.bm.disconnect(sid1, '/')) assert not self.bm.is_connected(sid1, '/') assert self.bm.is_connected(sid2, '/foo') - self.bm.disconnect(sid2, '/foo') + _run(self.bm.disconnect(sid2, '/foo')) assert not self.bm.is_connected(sid2, '/foo') assert dict(self.bm.rooms['/'][None]) == {sid3: '456'} assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'} @@ -98,10 +98,10 @@ class TestAsyncManager(unittest.TestCase): sid2 = self.bm.connect('123', '/foo') sid3 = self.bm.connect('456', '/') sid4 = self.bm.connect('456', '/foo') - self.bm.disconnect(sid1, '/') - self.bm.disconnect(sid2, '/foo') - self.bm.disconnect(sid1, '/') - self.bm.disconnect(sid2, '/foo') + _run(self.bm.disconnect(sid1, '/')) + _run(self.bm.disconnect(sid2, '/foo')) + _run(self.bm.disconnect(sid1, '/')) + _run(self.bm.disconnect(sid2, '/foo')) assert dict(self.bm.rooms['/'][None]) == {sid3: '456'} assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'} assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'} @@ -112,8 +112,8 @@ class TestAsyncManager(unittest.TestCase): sid2 = self.bm.connect('456', '/foo') self.bm.enter_room(sid1, '/foo', 'bar') self.bm.enter_room(sid2, '/foo', 'baz') - self.bm.disconnect(sid1, '/foo') - self.bm.disconnect(sid2, '/foo') + _run(self.bm.disconnect(sid1, '/foo')) + _run(self.bm.disconnect(sid2, '/foo')) assert self.bm.rooms == {} def test_disconnect_with_callbacks(self): @@ -123,9 +123,9 @@ class TestAsyncManager(unittest.TestCase): self.bm._generate_ack_id(sid1, 'f') self.bm._generate_ack_id(sid2, 'g') self.bm._generate_ack_id(sid3, 'h') - self.bm.disconnect(sid2, '/foo') + _run(self.bm.disconnect(sid2, '/foo')) assert sid2 not in self.bm.callbacks - self.bm.disconnect(sid1, '/') + _run(self.bm.disconnect(sid1, '/')) assert sid1 not in self.bm.callbacks assert sid3 in self.bm.callbacks @@ -176,7 +176,7 @@ class TestAsyncManager(unittest.TestCase): sid1 = self.bm.connect('123', '/') sid2 = self.bm.connect('456', '/') sid3 = self.bm.connect('789', '/') - self.bm.disconnect(sid3, '/') + _run(self.bm.disconnect(sid3, '/')) assert sid3 not in self.bm.rooms['/'][None] participants = list(self.bm.get_participants('/', None)) assert len(participants) == 2 diff --git a/tests/asyncio/test_asyncio_pubsub_manager.py b/tests/asyncio/test_asyncio_pubsub_manager.py index c95c073..80a821b 100644 --- a/tests/asyncio/test_asyncio_pubsub_manager.py +++ b/tests/asyncio/test_asyncio_pubsub_manager.py @@ -176,6 +176,19 @@ class TestAsyncPubSubManager(unittest.TestCase): {'method': 'disconnect', 'sid': sid, 'namespace': '/foo'} ) + def test_disconnect(self): + _run(self.pm.disconnect('foo', '/')) + self.pm._publish.mock.assert_called_once_with( + {'method': 'disconnect', 'sid': 'foo', 'namespace': '/'} + ) + + def test_disconnect_ignore_queue(self): + sid = self.pm.connect('123', '/') + self.pm.pre_disconnect(sid, '/') + _run(self.pm.disconnect(sid, '/', ignore_queue=True)) + self.pm._publish.mock.assert_not_called() + assert self.pm.is_connected(sid, '/') is False + def test_close_room(self): _run(self.pm.close_room('foo')) self.pm._publish.mock.assert_called_once_with( diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index eec531c..b9e4d7e 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -597,14 +597,15 @@ class TestAsyncServer(unittest.TestCase): def test_handle_disconnect(self, eio): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() - s.manager.disconnect = mock.MagicMock() + s.manager.disconnect = AsyncMock() handler = mock.MagicMock() s.on('disconnect', handler) _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0')) _run(s._handle_eio_disconnect('123')) handler.assert_called_once_with('1') - s.manager.disconnect.assert_called_once_with('1', '/') + s.manager.disconnect.mock.assert_called_once_with( + '1', '/', ignore_queue=True) assert s.environ == {} def test_handle_disconnect_namespace(self, eio):