From 516a2958f4e87041aeeea0a0a8e3622d3d636184 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sat, 3 Aug 2019 20:41:58 +0100 Subject: [PATCH] Disconnect Engine.IO connection when server disconnects a client (https://github.com/miguelgrinberg/Flask-SocketIO/issues/1017) --- socketio/asyncio_client.py | 22 ++++++++++-- socketio/asyncio_server.py | 2 ++ socketio/client.py | 24 ++++++++++--- socketio/server.py | 2 ++ tests/asyncio/test_asyncio_client.py | 52 ++++++++++++++++++++++++++++ tests/asyncio/test_asyncio_server.py | 6 ++++ tests/common/test_client.py | 49 ++++++++++++++++++++++++++ 7 files changed, 150 insertions(+), 7 deletions(-) diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index d71ccec..40b182c 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -102,6 +102,7 @@ class AsyncClient(client.Client): engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: six.raise_from(exceptions.ConnectionError(exc.args[0]), None) + self.connected = True async def wait(self): """Wait until the connection with the server ends. @@ -232,6 +233,7 @@ class AsyncClient(client.Client): namespace=n)) await self._send_packet(packet.Packet( packet.DISCONNECT, namespace='/')) + self.connected = False await self.eio.disconnect(abort=True) def start_background_task(self, target, *args, **kwargs): @@ -286,10 +288,18 @@ class AsyncClient(client.Client): self.namespaces.append(namespace) async def _handle_disconnect(self, namespace): + if not self.connected: + return namespace = namespace or '/' + if namespace == '/': + for n in self.namespaces: + await self._trigger_event('disconnect', namespace=n) + self.namespaces = [] await self._trigger_event('disconnect', namespace=namespace) if namespace in self.namespaces: self.namespaces.remove(namespace) + if namespace == '/': + self.connected = False async def _handle_event(self, namespace, id, data): namespace = namespace or '/' @@ -335,6 +345,9 @@ class AsyncClient(client.Client): namespace)) if namespace in self.namespaces: self.namespaces.remove(namespace) + if namespace == '/': + self.namespaces = [] + self.connected = False async def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" @@ -431,9 +444,12 @@ class AsyncClient(client.Client): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') self._reconnect_abort.set() - for n in self.namespaces: - await self._trigger_event('disconnect', namespace=n) - await self._trigger_event('disconnect', namespace='/') + if self.connected: + for n in self.namespaces: + await self._trigger_event('disconnect', namespace=n) + await self._trigger_event('disconnect', namespace='/') + self.namespaces = [] + self.connected = False self.callbacks = {} self._binary_packet = None self.sid = None diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 412e94c..8162452 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -312,6 +312,8 @@ class AsyncServer(server.Server): namespace=namespace)) await self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) + if namespace == '/': + await self.eio.disconnect(sid) async def handle_request(self, *args, **kwargs): """Handle an HTTP request from the client. diff --git a/socketio/client.py b/socketio/client.py index c0bdc1a..2c5d8b2 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -112,6 +112,7 @@ class Client(object): self.socketio_path = None self.sid = None + self.connected = False self.namespaces = [] self.handlers = {} self.namespace_handlers = {} @@ -261,6 +262,7 @@ class Client(object): engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: six.raise_from(exceptions.ConnectionError(exc.args[0]), None) + self.connected = True def wait(self): """Wait until the connection with the server ends. @@ -377,6 +379,7 @@ class Client(object): self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n)) self._send_packet(packet.Packet( packet.DISCONNECT, namespace='/')) + self.connected = False self.eio.disconnect(abort=True) def transport(self): @@ -445,10 +448,18 @@ class Client(object): self.namespaces.append(namespace) def _handle_disconnect(self, namespace): + if not self.connected: + return namespace = namespace or '/' + if namespace == '/': + for n in self.namespaces: + self._trigger_event('disconnect', namespace=n) + self.namespaces = [] self._trigger_event('disconnect', namespace=namespace) if namespace in self.namespaces: self.namespaces.remove(namespace) + if namespace == '/': + self.connected = False def _handle_event(self, namespace, id, data): namespace = namespace or '/' @@ -490,6 +501,9 @@ class Client(object): namespace)) if namespace in self.namespaces: self.namespaces.remove(namespace) + if namespace == '/': + self.namespaces = [] + self.connected = False def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" @@ -516,7 +530,6 @@ class Client(object): self.logger.info( 'Connection failed, new attempt in {:.02f} seconds'.format( delay)) - print('***', self._reconnect_abort.wait) if self._reconnect_abort.wait(delay): self.logger.info('Reconnect task aborted') break @@ -576,9 +589,12 @@ class Client(object): def _handle_eio_disconnect(self): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') - for n in self.namespaces: - self._trigger_event('disconnect', namespace=n) - self._trigger_event('disconnect', namespace='/') + if self.connected: + for n in self.namespaces: + self._trigger_event('disconnect', namespace=n) + self._trigger_event('disconnect', namespace='/') + self.namespaces = [] + self.connected = False self.callbacks = {} self._binary_packet = None self.sid = None diff --git a/socketio/server.py b/socketio/server.py index 1fc3a10..2ddd94c 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -504,6 +504,8 @@ class Server(object): namespace=namespace)) self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) + if namespace == '/': + self.eio.disconnect(sid) def transport(self, sid): """Return the name of the transport used by the client. diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index 27a3298..ecf88bc 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -411,28 +411,51 @@ class TestAsyncClient(unittest.TestCase): def test_handle_disconnect(self): c = asyncio_client.AsyncClient() + c.connected = True c._trigger_event = AsyncMock() _run(c._handle_disconnect('/')) c._trigger_event.mock.assert_called_once_with( 'disconnect', namespace='/') + self.assertFalse(c.connected) + _run(c._handle_disconnect('/')) + self.assertEqual(c._trigger_event.mock.call_count, 1) def test_handle_disconnect_namespace(self): c = asyncio_client.AsyncClient() + c.connected = True c.namespaces = ['/foo', '/bar'] c._trigger_event = AsyncMock() _run(c._handle_disconnect('/foo')) c._trigger_event.mock.assert_called_once_with( 'disconnect', namespace='/foo') self.assertEqual(c.namespaces, ['/bar']) + self.assertTrue(c.connected) def test_handle_disconnect_unknown_namespace(self): c = asyncio_client.AsyncClient() + c.connected = True c.namespaces = ['/foo', '/bar'] c._trigger_event = AsyncMock() _run(c._handle_disconnect('/baz')) c._trigger_event.mock.assert_called_once_with( 'disconnect', namespace='/baz') self.assertEqual(c.namespaces, ['/foo', '/bar']) + self.assertTrue(c.connected) + + def test_handle_disconnect_all_namespaces(self): + c = asyncio_client.AsyncClient() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._trigger_event = AsyncMock() + _run(c._handle_disconnect('/')) + c._trigger_event.mock.assert_any_call( + 'disconnect', namespace='/') + c._trigger_event.mock.assert_any_call( + 'disconnect', namespace='/foo') + c._trigger_event.mock.assert_any_call( + 'disconnect', namespace='/bar') + self.assertEqual(c.namespaces, []) + self.assertFalse(c.connected) def test_handle_event(self): c = asyncio_client.AsyncClient() @@ -519,15 +542,27 @@ class TestAsyncClient(unittest.TestCase): def test_handle_error(self): c = asyncio_client.AsyncClient() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._handle_error('/') + self.assertEqual(c.namespaces, []) + self.assertFalse(c.connected) + + def test_handle_error_namespace(self): + c = asyncio_client.AsyncClient() + c.connected = True c.namespaces = ['/foo', '/bar'] c._handle_error('/bar') self.assertEqual(c.namespaces, ['/foo']) + self.assertTrue(c.connected) def test_handle_error_unknown_namespace(self): c = asyncio_client.AsyncClient() + c.connected = True c.namespaces = ['/foo', '/bar'] c._handle_error('/baz') self.assertEqual(c.namespaces, ['/foo', '/bar']) + self.assertTrue(c.connected) def test_trigger_event(self): c = asyncio_client.AsyncClient() @@ -556,6 +591,19 @@ class TestAsyncClient(unittest.TestCase): _run(c._trigger_event('foo', '/', 1, '2')) self.assertEqual(result, [1, '2']) + def test_trigger_event_unknown_namespace(self): + c = asyncio_client.AsyncClient() + result = [] + + class MyNamespace(asyncio_namespace.AsyncClientNamespace): + def on_foo(self, a, b): + result.append(a) + result.append(b) + + c.register_namespace(MyNamespace('/')) + _run(c._trigger_event('foo', '/bar', 1, '2')) + self.assertEqual(result, []) + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, side_effect=asyncio.TimeoutError) @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) @@ -663,6 +711,7 @@ class TestAsyncClient(unittest.TestCase): def test_eio_disconnect(self): c = asyncio_client.AsyncClient() + c.connected = True c._trigger_event = AsyncMock() c.sid = 'foo' c.eio.state = 'connected' @@ -670,9 +719,11 @@ class TestAsyncClient(unittest.TestCase): c._trigger_event.mock.assert_called_once_with( 'disconnect', namespace='/') self.assertIsNone(c.sid) + self.assertFalse(c.connected) def test_eio_disconnect_namespaces(self): c = asyncio_client.AsyncClient() + c.connected = True c.namespaces = ['/foo', '/bar'] c._trigger_event = AsyncMock() c.sid = 'foo' @@ -682,6 +733,7 @@ class TestAsyncClient(unittest.TestCase): c._trigger_event.mock.assert_any_call('disconnect', namespace='/bar') c._trigger_event.mock.assert_any_call('disconnect', namespace='/') self.assertIsNone(c.sid) + self.assertFalse(c.connected) def test_eio_disconnect_reconnect(self): c = asyncio_client.AsyncClient(reconnection=True) diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index e9a2c66..02e13ec 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -596,27 +596,33 @@ class TestAsyncServer(unittest.TestCase): def test_disconnect(self, eio): eio.return_value.send = AsyncMock() + eio.return_value.disconnect = AsyncMock() s = asyncio_server.AsyncServer() _run(s._handle_eio_connect('123', 'environ')) _run(s.disconnect('123')) s.eio.send.mock.assert_any_call('123', '1', binary=False) + s.eio.disconnect.mock.assert_called_once_with('123') def test_disconnect_namespace(self, eio): eio.return_value.send = AsyncMock() + eio.return_value.disconnect = AsyncMock() s = asyncio_server.AsyncServer() _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0/foo')) _run(s.disconnect('123', namespace='/foo')) s.eio.send.mock.assert_any_call('123', '1/foo', binary=False) + s.eio.disconnect.mock.assert_not_called() def test_disconnect_twice(self, eio): eio.return_value.send = AsyncMock() + eio.return_value.disconnect = AsyncMock() s = asyncio_server.AsyncServer() _run(s._handle_eio_connect('123', 'environ')) _run(s.disconnect('123')) calls = s.eio.send.mock.call_count _run(s.disconnect('123')) self.assertEqual(calls, s.eio.send.mock.call_count) + self.assertEqual(s.eio.disconnect.mock.call_count, 1) def test_disconnect_twice_namespace(self, eio): eio.return_value.send = AsyncMock() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index da21f44..fea7e10 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -536,27 +536,47 @@ class TestClient(unittest.TestCase): def test_handle_disconnect(self): c = client.Client() + c.connected = True c._trigger_event = mock.MagicMock() c._handle_disconnect('/') c._trigger_event.assert_called_once_with('disconnect', namespace='/') + self.assertFalse(c.connected) + c._handle_disconnect('/') + self.assertEqual(c._trigger_event.call_count, 1) def test_handle_disconnect_namespace(self): c = client.Client() + c.connected = True c.namespaces = ['/foo', '/bar'] c._trigger_event = mock.MagicMock() c._handle_disconnect('/foo') c._trigger_event.assert_called_once_with('disconnect', namespace='/foo') self.assertEqual(c.namespaces, ['/bar']) + self.assertTrue(c.connected) def test_handle_disconnect_unknown_namespace(self): c = client.Client() + c.connected = True c.namespaces = ['/foo', '/bar'] c._trigger_event = mock.MagicMock() c._handle_disconnect('/baz') c._trigger_event.assert_called_once_with('disconnect', namespace='/baz') self.assertEqual(c.namespaces, ['/foo', '/bar']) + self.assertTrue(c.connected) + + def test_handle_disconnect_all_namespaces(self): + c = client.Client() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._trigger_event = mock.MagicMock() + c._handle_disconnect('/') + c._trigger_event.assert_any_call('disconnect', namespace='/') + c._trigger_event.assert_any_call('disconnect', namespace='/foo') + c._trigger_event.assert_any_call('disconnect', namespace='/bar') + self.assertEqual(c.namespaces, []) + self.assertFalse(c.connected) def test_handle_event(self): c = client.Client() @@ -630,15 +650,27 @@ class TestClient(unittest.TestCase): def test_handle_error(self): c = client.Client() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._handle_error('/') + self.assertEqual(c.namespaces, []) + self.assertFalse(c.connected) + + def test_handle_error_namespace(self): + c = client.Client() + c.connected = True c.namespaces = ['/foo', '/bar'] c._handle_error('/bar') self.assertEqual(c.namespaces, ['/foo']) + self.assertTrue(c.connected) def test_handle_error_unknown_namespace(self): c = client.Client() + c.connected = True c.namespaces = ['/foo', '/bar'] c._handle_error('/baz') self.assertEqual(c.namespaces, ['/foo', '/bar']) + self.assertTrue(c.connected) def test_trigger_event(self): c = client.Client() @@ -667,6 +699,19 @@ class TestClient(unittest.TestCase): c._trigger_event('foo', '/', 1, '2') self.assertEqual(result, [1, '2']) + def test_trigger_event_unknown_namespace(self): + c = client.Client() + result = [] + + class MyNamespace(namespace.ClientNamespace): + def on_foo(self, a, b): + result.append(a) + result.append(b) + + c.register_namespace(MyNamespace('/')) + c._trigger_event('foo', '/bar', 1, '2') + self.assertEqual(result, []) + @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) def test_handle_reconnect(self, random): c = client.Client() @@ -774,6 +819,7 @@ class TestClient(unittest.TestCase): def test_eio_disconnect(self): c = client.Client() + c.connected = True c._trigger_event = mock.MagicMock() c.start_background_task = mock.MagicMock() c.sid = 'foo' @@ -781,9 +827,11 @@ class TestClient(unittest.TestCase): c._handle_eio_disconnect() c._trigger_event.assert_called_once_with('disconnect', namespace='/') self.assertIsNone(c.sid) + self.assertFalse(c.connected) def test_eio_disconnect_namespaces(self): c = client.Client() + c.connected = True c.namespaces = ['/foo', '/bar'] c._trigger_event = mock.MagicMock() c.start_background_task = mock.MagicMock() @@ -794,6 +842,7 @@ class TestClient(unittest.TestCase): c._trigger_event.assert_any_call('disconnect', namespace='/bar') c._trigger_event.assert_any_call('disconnect', namespace='/') self.assertIsNone(c.sid) + self.assertFalse(c.connected) def test_eio_disconnect_reconnect(self): c = client.Client(reconnection=True)