From 811e044a46b7d6e4d94bf870e59d0cd8187850d3 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 19 May 2024 20:30:59 +0100 Subject: [PATCH] New shutdown() method added to the client (Fixes #1333) --- src/socketio/async_client.py | 24 +++++++++++++-- src/socketio/client.py | 14 +++++++++ tests/async/test_client.py | 60 ++++++++++++++++++++++++++++++++++++ tests/common/test_client.py | 57 ++++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 2 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index 9184d02..d7e8c27 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -318,6 +318,21 @@ class AsyncClient(base_client.BaseClient): namespace=n)) await self.eio.disconnect(abort=True) + async def shutdown(self): + """Stop the client. + + If the client is connected to a server, it is disconnected. If the + client is attempting to reconnect to server, the reconnection attempts + are stopped. If the client is not connected to a server and is not + attempting to reconnect, then this function does nothing. + """ + if self.connected: + await self.disconnect() + elif self._reconnect_task: # pragma: no branch + self._reconnect_abort.set() + print(self._reconnect_task) + await self._reconnect_task + def start_background_task(self, target, *args, **kwargs): """Start a background task using the appropriate async model. @@ -467,15 +482,20 @@ class AsyncClient(base_client.BaseClient): self.logger.info( 'Connection failed, new attempt in {:.02f} seconds'.format( delay)) + abort = False try: await asyncio.wait_for(self._reconnect_abort.wait(), delay) + abort = True + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: # pragma: no cover + abort = True + if abort: self.logger.info('Reconnect task aborted') for n in self.connection_namespaces: await self._trigger_event('__disconnect_final', namespace=n) break - except (asyncio.TimeoutError, asyncio.CancelledError): - pass attempt_count += 1 try: await self.connect(self.connection_url, diff --git a/src/socketio/client.py b/src/socketio/client.py index 905bb1e..e5150e9 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -298,6 +298,20 @@ class Client(base_client.BaseClient): packet.DISCONNECT, namespace=n)) self.eio.disconnect(abort=True) + def shutdown(self): + """Stop the client. + + If the client is connected to a server, it is disconnected. If the + client is attempting to reconnect to server, the reconnection attempts + are stopped. If the client is not connected to a server and is not + attempting to reconnect, then this function does nothing. + """ + if self.connected: + self.disconnect() + elif self._reconnect_task: # pragma: no branch + self._reconnect_abort.set() + self._reconnect_task.join() + def start_background_task(self, target, *args, **kwargs): """Start a background task using the appropriate async model. diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 8b8f97a..82c66e1 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -990,6 +990,66 @@ class TestAsyncClient(unittest.TestCase): c._trigger_event.mock.assert_called_once_with('__disconnect_final', namespace='/') + def test_shutdown_disconnect(self): + c = async_client.AsyncClient() + c.connected = True + c.namespaces = {'/': '1'} + c._trigger_event = AsyncMock() + c._send_packet = AsyncMock() + c.eio = mock.MagicMock() + c.eio.disconnect = AsyncMock() + c.eio.state = 'connected' + _run(c.shutdown()) + assert c._trigger_event.mock.call_count == 0 + assert c._send_packet.mock.call_count == 1 + expected_packet = packet.Packet(packet.DISCONNECT, namespace='/') + assert ( + c._send_packet.mock.call_args_list[0][0][0].encode() + == expected_packet.encode() + ) + c.eio.disconnect.mock.assert_called_once_with(abort=True) + + def test_shutdown_disconnect_namespaces(self): + c = async_client.AsyncClient() + c.connected = True + c.namespaces = {'/foo': '1', '/bar': '2'} + c._trigger_event = AsyncMock() + c._send_packet = AsyncMock() + c.eio = mock.MagicMock() + c.eio.disconnect = AsyncMock() + c.eio.state = 'connected' + _run(c.shutdown()) + assert c._trigger_event.mock.call_count == 0 + assert c._send_packet.mock.call_count == 2 + expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo') + assert ( + c._send_packet.mock.call_args_list[0][0][0].encode() + == expected_packet.encode() + ) + expected_packet = packet.Packet(packet.DISCONNECT, namespace='/bar') + assert ( + c._send_packet.mock.call_args_list[1][0][0].encode() + == expected_packet.encode() + ) + + @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) + def test_shutdown_reconnect(self, random): + c = async_client.AsyncClient() + c.connection_namespaces = ['/'] + c._reconnect_task = AsyncMock()() + c._trigger_event = AsyncMock() + c.connect = AsyncMock(side_effect=exceptions.ConnectionError) + + async def r(): + task = c.start_background_task(c._handle_reconnect) + await asyncio.sleep(0.1) + await c.shutdown() + await task + + _run(r()) + c._trigger_event.mock.assert_called_once_with('__disconnect_final', + namespace='/') + def test_handle_eio_connect(self): c = async_client.AsyncClient() c.connection_namespaces = ['/', '/foo'] diff --git a/tests/common/test_client.py b/tests/common/test_client.py index d1fcf8e..a9117bc 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1,4 +1,5 @@ import logging +import time import unittest from unittest import mock @@ -636,6 +637,7 @@ class TestClient(unittest.TestCase): def test_disconnect_namespaces(self): c = client.Client() + c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._send_packet = mock.MagicMock() @@ -1128,6 +1130,61 @@ class TestClient(unittest.TestCase): c._trigger_event.assert_called_once_with('__disconnect_final', namespace='/') + def test_shutdown_disconnect(self): + c = client.Client() + c.connected = True + c.namespaces = {'/': '1'} + c._trigger_event = mock.MagicMock() + c._send_packet = mock.MagicMock() + c.eio = mock.MagicMock() + c.eio.state = 'connected' + c.shutdown() + assert c._trigger_event.call_count == 0 + assert c._send_packet.call_count == 1 + expected_packet = packet.Packet(packet.DISCONNECT, namespace='/') + assert ( + c._send_packet.call_args_list[0][0][0].encode() + == expected_packet.encode() + ) + c.eio.disconnect.assert_called_once_with(abort=True) + + def test_shutdown_disconnect_namespaces(self): + c = client.Client() + c.connected = True + c.namespaces = {'/foo': '1', '/bar': '2'} + c._trigger_event = mock.MagicMock() + c._send_packet = mock.MagicMock() + c.eio = mock.MagicMock() + c.eio.state = 'connected' + c.shutdown() + assert c._trigger_event.call_count == 0 + assert c._send_packet.call_count == 2 + expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo') + assert ( + c._send_packet.call_args_list[0][0][0].encode() + == expected_packet.encode() + ) + expected_packet = packet.Packet(packet.DISCONNECT, namespace='/bar') + assert ( + c._send_packet.call_args_list[1][0][0].encode() + == expected_packet.encode() + ) + + @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) + def test_shutdown_reconnect(self, random): + c = client.Client() + c.connection_namespaces = ['/'] + c._reconnect_task = mock.MagicMock() + c._trigger_event = mock.MagicMock() + c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError) + task = c.start_background_task(c._handle_reconnect) + time.sleep(0.1) + c.shutdown() + task.join() + c._trigger_event.assert_called_once_with('__disconnect_final', + namespace='/') + assert c._reconnect_task.join.called_once_with() + def test_handle_eio_connect(self): c = client.Client() c.connection_namespaces = ['/', '/foo']