diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index 03aa931..157cd87 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -68,12 +68,19 @@ class AsyncClient(client.Client): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom - query string parameters if required by the server. + query string parameters if required by the server. If a + function is provided, the client will invoke it to obtain + the URL each time a connection or reconnection is + attempted. :param headers: A dictionary with custom headers to send with the - connection request. + connection request. If a function is provided, the + client will invoke it to obtain the headers dictionary + each time a connection or reconnection is attempted. :param auth: Authentication data passed to the server with the connection request, normally a dictionary with one or - more string key/value pairs. + more string key/value pairs. If a function is provided, + the client will invoke it to obtain the authentication + data each time a connection or reconnection is attempted. :param transports: The list of allowed transports. Valid transports are ``'polling'`` and ``'websocket'``. If not given, the polling transport is connected first, @@ -124,8 +131,10 @@ class AsyncClient(client.Client): self._connect_event = self.eio.create_event() else: self._connect_event.clear() + real_url = await self._get_real_value(self.connection_url) + real_headers = await self._get_real_value(self.connection_headers) try: - await self.eio.connect(url, headers=headers, + await self.eio.connect(real_url, headers=real_headers, transports=transports, engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: @@ -320,6 +329,15 @@ class AsyncClient(client.Client): """ return await self.eio.sleep(seconds) + async def _get_real_value(self, value): + """Return the actual value, for parameters that can also be given as + callables.""" + if not callable(value): + return value + if asyncio.iscoroutinefunction(value): + return await value() + return value() + async def _send_packet(self, pkt): """Send a Socket.IO packet to the server.""" encoded_packet = pkt.encode() @@ -462,9 +480,10 @@ class AsyncClient(client.Client): """Handle the Engine.IO connection event.""" self.logger.info('Engine.IO connection established') self.sid = self.eio.sid + real_auth = await self._get_real_value(self.connection_auth) for n in self.connection_namespaces: await self._send_packet(packet.Packet( - packet.CONNECT, data=self.connection_auth, namespace=n)) + packet.CONNECT, data=real_auth, namespace=n)) async def _handle_eio_message(self, data): """Dispatch Engine.IO messages.""" diff --git a/socketio/client.py b/socketio/client.py index 24bf72e..84fa7e0 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -240,12 +240,19 @@ class Client(object): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom - query string parameters if required by the server. + query string parameters if required by the server. If a + function is provided, the client will invoke it to obtain + the URL each time a connection or reconnection is + attempted. :param headers: A dictionary with custom headers to send with the - connection request. + connection request. If a function is provided, the + client will invoke it to obtain the headers dictionary + each time a connection or reconnection is attempted. :param auth: Authentication data passed to the server with the connection request, normally a dictionary with one or - more string key/value pairs. + more string key/value pairs. If a function is provided, + the client will invoke it to obtain the authentication + data each time a connection or reconnection is attempted. :param transports: The list of allowed transports. Valid transports are ``'polling'`` and ``'websocket'``. If not given, the polling transport is connected first, @@ -294,8 +301,11 @@ class Client(object): self._connect_event = self.eio.create_event() else: self._connect_event.clear() + real_url = self._get_real_value(self.connection_url) + real_headers = self._get_real_value(self.connection_headers) try: - self.eio.connect(url, headers=headers, transports=transports, + self.eio.connect(real_url, headers=real_headers, + transports=transports, engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: self._trigger_event( @@ -490,6 +500,13 @@ class Client(object): """ return self.eio.sleep(seconds) + def _get_real_value(self, value): + """Return the actual value, for parameters that can also be given as + callables.""" + if not callable(value): + return value + return value() + def _send_packet(self, pkt): """Send a Socket.IO packet to the server.""" encoded_packet = pkt.encode() @@ -628,9 +645,10 @@ class Client(object): """Handle the Engine.IO connection event.""" self.logger.info('Engine.IO connection established') self.sid = self.eio.sid + real_auth = self._get_real_value(self.connection_auth) for n in self.connection_namespaces: self._send_packet(packet.Packet( - packet.CONNECT, data=self.connection_auth, namespace=n)) + packet.CONNECT, data=real_auth, namespace=n)) def _handle_eio_message(self, data): """Dispatch Engine.IO messages.""" diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index e222bb8..38abcdd 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -75,6 +75,30 @@ class TestAsyncClient(unittest.TestCase): engineio_path='path', ) + def test_connect_functions(self): + async def headers(): + return 'headers' + + c = asyncio_client.AsyncClient() + c.eio.connect = AsyncMock() + _run( + c.connect( + lambda: 'url', + headers=headers, + auth='auth', + transports='transports', + namespaces=['/foo', '/', '/bar'], + socketio_path='path', + wait=False, + ) + ) + c.eio.connect.mock.assert_called_once_with( + 'url', + headers='headers', + transports='transports', + engineio_path='path', + ) + def test_connect_one_namespace(self): c = asyncio_client.AsyncClient() c.eio.connect = AsyncMock() @@ -960,6 +984,29 @@ class TestAsyncClient(unittest.TestCase): == expected_packet.encode() ) + def test_handle_eio_connect_function(self): + c = asyncio_client.AsyncClient() + c.connection_namespaces = ['/', '/foo'] + c.connection_auth = lambda: 'auth' + c._send_packet = AsyncMock() + c.eio.sid = 'foo' + assert c.sid is None + _run(c._handle_eio_connect()) + assert c.sid == 'foo' + assert c._send_packet.mock.call_count == 2 + expected_packet = packet.Packet( + packet.CONNECT, data='auth', namespace='/') + assert ( + c._send_packet.mock.call_args_list[0][0][0].encode() + == expected_packet.encode() + ) + expected_packet = packet.Packet( + packet.CONNECT, data='auth', namespace='/foo') + assert ( + c._send_packet.mock.call_args_list[1][0][0].encode() + == expected_packet.encode() + ) + def test_handle_eio_message(self): c = asyncio_client.AsyncClient() c._handle_connect = AsyncMock() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 2f85e6c..9eb8211 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -173,6 +173,25 @@ class TestClient(unittest.TestCase): engineio_path='path', ) + def test_connect_functions(self): + c = client.Client() + c.eio.connect = mock.MagicMock() + c.connect( + lambda: 'url', + headers=lambda: 'headers', + auth='auth', + transports='transports', + namespaces=['/foo', '/', '/bar'], + socketio_path='path', + wait=False, + ) + c.eio.connect.assert_called_once_with( + 'url', + headers='headers', + transports='transports', + engineio_path='path', + ) + def test_connect_one_namespace(self): c = client.Client() c.eio.connect = mock.MagicMock() @@ -1030,6 +1049,29 @@ class TestClient(unittest.TestCase): == expected_packet.encode() ) + def test_handle_eio_connect_function(self): + c = client.Client() + c.connection_namespaces = ['/', '/foo'] + c.connection_auth = lambda: 'auth' + c._send_packet = mock.MagicMock() + c.eio.sid = 'foo' + assert c.sid is None + c._handle_eio_connect() + assert c.sid == 'foo' + assert c._send_packet.call_count == 2 + expected_packet = packet.Packet( + packet.CONNECT, data='auth', namespace='/') + assert ( + c._send_packet.call_args_list[0][0][0].encode() + == expected_packet.encode() + ) + expected_packet = packet.Packet( + packet.CONNECT, data='auth', namespace='/foo') + assert ( + c._send_packet.call_args_list[1][0][0].encode() + == expected_packet.encode() + ) + def test_handle_eio_message(self): c = client.Client() c._handle_connect = mock.MagicMock()