From 4da6d74f56a58e68b0aef08212347097dd73cda9 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Mon, 15 Feb 2021 12:27:29 +0000 Subject: [PATCH] Added wait argument to client's connect method (Fixes #634) --- socketio/asyncio_client.py | 42 ++++++++++-- socketio/client.py | 38 +++++++++-- tests/asyncio/test_asyncio_client.py | 97 ++++++++++++++++++++++++++++ tests/common/test_client.py | 89 +++++++++++++++++++++++++ 4 files changed, 256 insertions(+), 10 deletions(-) diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index 29f6e44..d89c627 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -63,7 +63,8 @@ class AsyncClient(client.Client): return True async def connect(self, url, headers={}, transports=None, - namespaces=None, socketio_path='socket.io'): + namespaces=None, socketio_path='socket.io', wait=True, + wait_timeout=1): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -80,18 +81,26 @@ class AsyncClient(client.Client): :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait: if set to ``True`` (the default) the call only returns + when all the namespaces are connected. If set to + ``False``, the call returns as soon as the Engine.IO + transport is connected, and the namespaces will connect + in the background. + :param wait_timeout: How long the client should wait for the + connection. The default is 1 second. This + argument is only considered when ``wait`` is set + to ``True``. Note: this method is a coroutine. - Note: The connection mechannism occurs in the background and will - complete at some point after this function returns. The connection - will be established when the ``connect`` event is invoked. - Example usage:: sio = socketio.AsyncClient() sio.connect('http://localhost:5000') """ + if self.connected: + raise exceptions.ConnectionError('Already connected') + self.connection_url = url self.connection_headers = headers self.connection_transports = transports @@ -106,6 +115,11 @@ class AsyncClient(client.Client): elif isinstance(namespaces, str): namespaces = [namespaces] self.connection_namespaces = namespaces + self.namespaces = {} + if self._connect_event is None: + self._connect_event = self.eio.create_event() + else: + self._connect_event.clear() try: await self.eio.connect(url, headers=headers, transports=transports, @@ -115,6 +129,22 @@ class AsyncClient(client.Client): 'connect_error', '/', exc.args[1] if len(exc.args) > 1 else exc.args[0]) raise exceptions.ConnectionError(exc.args[0]) from None + + if wait: + try: + while True: + await asyncio.wait_for(self._connect_event.wait(), + wait_timeout) + self._connect_event.clear() + if set(self.namespaces) == set(self.connection_namespaces): + break + except asyncio.TimeoutError: + pass + if set(self.namespaces) != set(self.connection_namespaces): + await self.disconnect() + raise exceptions.ConnectionError( + 'One or more namespaces failed to connect') + self.connected = True async def wait(self): @@ -301,6 +331,7 @@ class AsyncClient(client.Client): self.logger.info('Namespace {} is connected'.format(namespace)) self.namespaces[namespace] = (data or {}).get('sid', self.sid) await self._trigger_event('connect', namespace=namespace) + self._connect_event.set() async def _handle_disconnect(self, namespace): if not self.connected: @@ -355,6 +386,7 @@ class AsyncClient(client.Client): elif not isinstance(data, (tuple, list)): data = (data,) await self._trigger_event('connect_error', namespace, *data) + self._connect_event.set() if namespace in self.namespaces: del self.namespaces[namespace] if namespace == '/': diff --git a/socketio/client.py b/socketio/client.py index 8eaa6ce..80c2e31 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -131,6 +131,7 @@ class Client(object): self.namespace_handlers = {} self.callbacks = {} self._binary_packet = None + self._connect_event = None self._reconnect_task = None self._reconnect_abort = None @@ -233,7 +234,8 @@ class Client(object): namespace_handler def connect(self, url, headers={}, transports=None, - namespaces=None, socketio_path='socket.io'): + namespaces=None, socketio_path='socket.io', wait=True, + wait_timeout=1): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -250,16 +252,24 @@ class Client(object): :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. - - Note: The connection mechannism occurs in the background and will - complete at some point after this function returns. The connection - will be established when the ``connect`` event is invoked. + :param wait: if set to ``True`` (the default) the call only returns + when all the namespaces are connected. If set to + ``False``, the call returns as soon as the Engine.IO + transport is connected, and the namespaces will connect + in the background. + :param wait_timeout: How long the client should wait for the + connection. The default is 1 second. This + argument is only considered when ``wait`` is set + to ``True``. Example usage:: sio = socketio.Client() sio.connect('http://localhost:5000') """ + if self.connected: + raise exceptions.ConnectionError('Already connected') + self.connection_url = url self.connection_headers = headers self.connection_transports = transports @@ -274,6 +284,11 @@ class Client(object): elif isinstance(namespaces, str): namespaces = [namespaces] self.connection_namespaces = namespaces + self.namespaces = {} + if self._connect_event is None: + self._connect_event = self.eio.create_event() + else: + self._connect_event.clear() try: self.eio.connect(url, headers=headers, transports=transports, engineio_path=socketio_path) @@ -282,6 +297,17 @@ class Client(object): 'connect_error', '/', exc.args[1] if len(exc.args) > 1 else exc.args[0]) raise exceptions.ConnectionError(exc.args[0]) from None + + if wait: + while self._connect_event.wait(timeout=wait_timeout): + self._connect_event.clear() + if set(self.namespaces) == set(self.connection_namespaces): + break + if set(self.namespaces) != set(self.connection_namespaces): + self.disconnect() + raise exceptions.ConnectionError( + 'One or more namespaces failed to connect') + self.connected = True def wait(self): @@ -483,6 +509,7 @@ class Client(object): self.logger.info('Namespace {} is connected'.format(namespace)) self.namespaces[namespace] = (data or {}).get('sid', self.sid) self._trigger_event('connect', namespace=namespace) + self._connect_event.set() def _handle_disconnect(self, namespace): if not self.connected: @@ -534,6 +561,7 @@ class Client(object): elif not isinstance(data, (tuple, list)): data = (data,) self._trigger_event('connect_error', namespace, *data) + self._connect_event.set() if namespace in self.namespaces: del self.namespaces[namespace] if namespace == '/': diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index a96ca0e..38193fd 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -58,6 +58,7 @@ class TestAsyncClient(unittest.TestCase): transports='transports', namespaces=['/foo', '/', '/bar'], socketio_path='path', + wait=False, ) ) assert c.connection_url == 'url' @@ -82,6 +83,7 @@ class TestAsyncClient(unittest.TestCase): transports='transports', namespaces='/foo', socketio_path='path', + wait=False, ) ) assert c.connection_url == 'url' @@ -107,6 +109,7 @@ class TestAsyncClient(unittest.TestCase): headers='headers', transports='transports', socketio_path='path', + wait=False, ) ) assert c.connection_url == 'url' @@ -131,6 +134,7 @@ class TestAsyncClient(unittest.TestCase): headers='headers', transports='transports', socketio_path='path', + wait=False, ) ) assert c.connection_url == 'url' @@ -159,9 +163,86 @@ class TestAsyncClient(unittest.TestCase): headers='headers', transports='transports', socketio_path='path', + wait=False, ) ) + def test_connect_twice(self): + c = asyncio_client.AsyncClient() + c.eio.connect = AsyncMock() + _run( + c.connect( + 'url', + wait=False, + ) + ) + with pytest.raises(exceptions.ConnectionError): + _run( + c.connect( + 'url', + wait=False, + ) + ) + + def test_connect_wait_single_namespace(self): + c = asyncio_client.AsyncClient() + c.eio.connect = AsyncMock() + c._connect_event = mock.MagicMock() + + async def mock_connect(): + c.namespaces = {'/': '123'} + return True + + c._connect_event.wait = mock_connect + _run( + c.connect( + 'url', + wait=True, + wait_timeout=0.01, + ) + ) + assert c.connected is True + + def test_connect_wait_two_namespaces(self): + c = asyncio_client.AsyncClient() + c.eio.connect = AsyncMock() + c._connect_event = mock.MagicMock() + + async def mock_connect(): + if c.namespaces == {}: + c.namespaces = {'/bar': '123'} + return True + elif c.namespaces == {'/bar': '123'}: + c.namespaces = {'/bar': '123', '/foo': '456'} + return True + return False + + c._connect_event.wait = mock_connect + _run( + c.connect( + 'url', + namespaces=['/foo', '/bar'], + wait=True, + wait_timeout=0.01, + ) + ) + assert c.connected is True + assert c.namespaces == {'/bar': '123', '/foo': '456'} + + def test_connect_timeout(self): + c = asyncio_client.AsyncClient() + c.eio.connect = AsyncMock() + c.disconnect = AsyncMock() + with pytest.raises(exceptions.ConnectionError): + _run( + c.connect( + 'url', + wait=True, + wait_timeout=0.01, + ) + ) + c.disconnect.mock.assert_called_once_with() + def test_wait_no_reconnect(self): c = asyncio_client.AsyncClient() c.eio.wait = AsyncMock() @@ -486,29 +567,35 @@ class TestAsyncClient(unittest.TestCase): def test_handle_connect(self): c = asyncio_client.AsyncClient() + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() c._send_packet = AsyncMock() _run(c._handle_connect('/', {'sid': '123'})) + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with('connect', namespace='/') c._send_packet.mock.assert_not_called() def test_handle_connect_with_namespaces(self): c = asyncio_client.AsyncClient() c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() c._send_packet = AsyncMock() _run(c._handle_connect('/', {'sid': '3'})) + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with('connect', namespace='/') assert c.namespaces == {'/': '3', '/foo': '1', '/bar': '2'} def test_handle_connect_namespace(self): c = asyncio_client.AsyncClient() c.namespaces = {'/foo': '1'} + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() c._send_packet = AsyncMock() _run(c._handle_connect('/foo', {'sid': '123'})) _run(c._handle_connect('/bar', {'sid': '2'})) assert c._trigger_event.mock.call_count == 1 + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with( 'connect', namespace='/bar') assert c.namespaces == {'/foo': '1', '/bar': '2'} @@ -658,11 +745,13 @@ class TestAsyncClient(unittest.TestCase): def test_handle_error(self): c = asyncio_client.AsyncClient() c.connected = True + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() c.namespaces = {'/foo': '1', '/bar': '2'} _run(c._handle_error('/', 'error')) assert c.namespaces == {} assert not c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with( 'connect_error', '/', 'error' ) @@ -670,21 +759,25 @@ class TestAsyncClient(unittest.TestCase): def test_handle_error_with_no_arguments(self): c = asyncio_client.AsyncClient() c.connected = True + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() c.namespaces = {'/foo': '1', '/bar': '2'} _run(c._handle_error('/', None)) assert c.namespaces == {} assert not c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with('connect_error', '/') def test_handle_error_namespace(self): c = asyncio_client.AsyncClient() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() _run(c._handle_error('/bar', ['error', 'message'])) assert c.namespaces == {'/foo': '1'} assert c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with( 'connect_error', '/bar', 'error', 'message' ) @@ -693,19 +786,23 @@ class TestAsyncClient(unittest.TestCase): c = asyncio_client.AsyncClient() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = AsyncMock() _run(c._handle_error('/bar', None)) assert c.namespaces == {'/foo': '1'} assert c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.mock.assert_called_once_with('connect_error', '/bar') def test_handle_error_unknown_namespace(self): c = asyncio_client.AsyncClient() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() _run(c._handle_error('/baz', 'error')) assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected + c._connect_event.set.assert_called_once_with() def test_trigger_event(self): c = asyncio_client.AsyncClient() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 9f9481c..61da8d9 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -157,6 +157,7 @@ class TestClient(unittest.TestCase): transports='transports', namespaces=['/foo', '/', '/bar'], socketio_path='path', + wait=False, ) assert c.connection_url == 'url' assert c.connection_headers == 'headers' @@ -179,6 +180,7 @@ class TestClient(unittest.TestCase): transports='transports', namespaces='/foo', socketio_path='path', + wait=False, ) assert c.connection_url == 'url' assert c.connection_headers == 'headers' @@ -202,6 +204,7 @@ class TestClient(unittest.TestCase): headers='headers', transports='transports', socketio_path='path', + wait=False, ) assert c.connection_url == 'url' assert c.connection_headers == 'headers' @@ -224,6 +227,7 @@ class TestClient(unittest.TestCase): headers='headers', transports='transports', socketio_path='path', + wait=False, ) assert c.connection_url == 'url' assert c.connection_headers == 'headers' @@ -250,8 +254,77 @@ class TestClient(unittest.TestCase): headers='headers', transports='transports', socketio_path='path', + wait=False, ) + def test_connect_twice(self): + c = client.Client() + c.eio.connect = mock.MagicMock() + c.connect( + 'url', + wait=False, + ) + with pytest.raises(exceptions.ConnectionError): + c.connect( + 'url', + wait=False, + ) + + def test_connect_wait_single_namespace(self): + c = client.Client() + c.eio.connect = mock.MagicMock() + c._connect_event = mock.MagicMock() + + def mock_connect(timeout): + assert timeout == 0.01 + c.namespaces = {'/': '123'} + return True + + c._connect_event.wait = mock_connect + c.connect( + 'url', + wait=True, + wait_timeout=0.01, + ) + assert c.connected is True + + def test_connect_wait_two_namespaces(self): + c = client.Client() + c.eio.connect = mock.MagicMock() + c._connect_event = mock.MagicMock() + + def mock_connect(timeout): + assert timeout == 0.01 + if c.namespaces == {}: + c.namespaces = {'/bar': '123'} + return True + elif c.namespaces == {'/bar': '123'}: + c.namespaces = {'/bar': '123', '/foo': '456'} + return True + return False + + c._connect_event.wait = mock_connect + c.connect( + 'url', + namespaces=['/foo', '/bar'], + wait=True, + wait_timeout=0.01, + ) + assert c.connected is True + assert c.namespaces == {'/bar': '123', '/foo': '456'} + + def test_connect_timeout(self): + c = client.Client() + c.eio.connect = mock.MagicMock() + c.disconnect = mock.MagicMock() + with pytest.raises(exceptions.ConnectionError): + c.connect( + 'url', + wait=True, + wait_timeout=0.01, + ) + c.disconnect.assert_called_once_with() + def test_wait_no_reconnect(self): c = client.Client() c.eio.wait = mock.MagicMock() @@ -602,30 +675,36 @@ class TestClient(unittest.TestCase): def test_handle_connect(self): c = client.Client() + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._send_packet = mock.MagicMock() c._handle_connect('/', {'sid': '123'}) assert c.namespaces == {'/': '123'} + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with('connect', namespace='/') c._send_packet.assert_not_called() def test_handle_connect_with_namespaces(self): c = client.Client() c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._send_packet = mock.MagicMock() c._handle_connect('/', {'sid': '3'}) + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with('connect', namespace='/') assert c.namespaces == {'/': '3', '/foo': '1', '/bar': '2'} def test_handle_connect_namespace(self): c = client.Client() c.namespaces = {'/foo': '1'} + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._send_packet = mock.MagicMock() c._handle_connect('/foo', {'sid': '123'}) c._handle_connect('/bar', {'sid': '2'}) assert c._trigger_event.call_count == 1 + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with('connect', namespace='/bar') assert c.namespaces == {'/foo': '1', '/bar': '2'} @@ -762,30 +841,36 @@ class TestClient(unittest.TestCase): c = client.Client() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._handle_error('/', 'error') assert c.namespaces == {} assert not c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with('connect_error', '/', 'error') def test_handle_error_with_no_arguments(self): c = client.Client() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._handle_error('/', None) assert c.namespaces == {} assert not c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with('connect_error', '/') def test_handle_error_namespace(self): c = client.Client() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._handle_error('/bar', ['error', 'message']) assert c.namespaces == {'/foo': '1'} assert c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with( 'connect_error', '/bar', 'error', 'message' ) @@ -794,19 +879,23 @@ class TestClient(unittest.TestCase): c = client.Client() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._trigger_event = mock.MagicMock() c._handle_error('/bar', None) assert c.namespaces == {'/foo': '1'} assert c.connected + c._connect_event.set.assert_called_once_with() c._trigger_event.assert_called_once_with('connect_error', '/bar') def test_handle_error_unknown_namespace(self): c = client.Client() c.connected = True c.namespaces = {'/foo': '1', '/bar': '2'} + c._connect_event = mock.MagicMock() c._handle_error('/baz', 'error') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected + c._connect_event.set.assert_called_once_with() def test_trigger_event(self): c = client.Client()