Browse Source

Allow functions to be used for URL, headers and auth data in client connection (Fixes #588)

pull/683/head
Miguel Grinberg 4 years ago
parent
commit
7d2e7f7eb3
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 29
      socketio/asyncio_client.py
  2. 28
      socketio/client.py
  3. 47
      tests/asyncio/test_asyncio_client.py
  4. 42
      tests/common/test_client.py

29
socketio/asyncio_client.py

@ -68,12 +68,19 @@ class AsyncClient(client.Client):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
query string parameters if required by the server. If a
function is provided, the client will invoke it to obtain
the URL each time a connection or reconnection is
attempted.
:param headers: A dictionary with custom headers to send with the
connection request.
connection request. If a function is provided, the
client will invoke it to obtain the headers dictionary
each time a connection or reconnection is attempted.
:param auth: Authentication data passed to the server with the
connection request, normally a dictionary with one or
more string key/value pairs.
more string key/value pairs. If a function is provided,
the client will invoke it to obtain the authentication
data each time a connection or reconnection is attempted.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
@ -124,8 +131,10 @@ class AsyncClient(client.Client):
self._connect_event = self.eio.create_event()
else:
self._connect_event.clear()
real_url = await self._get_real_value(self.connection_url)
real_headers = await self._get_real_value(self.connection_headers)
try:
await self.eio.connect(url, headers=headers,
await self.eio.connect(real_url, headers=real_headers,
transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
@ -320,6 +329,15 @@ class AsyncClient(client.Client):
"""
return await self.eio.sleep(seconds)
async def _get_real_value(self, value):
"""Return the actual value, for parameters that can also be given as
callables."""
if not callable(value):
return value
if asyncio.iscoroutinefunction(value):
return await value()
return value()
async def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
@ -462,9 +480,10 @@ class AsyncClient(client.Client):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
real_auth = await self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
await self._send_packet(packet.Packet(
packet.CONNECT, data=self.connection_auth, namespace=n))
packet.CONNECT, data=real_auth, namespace=n))
async def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""

28
socketio/client.py

@ -240,12 +240,19 @@ class Client(object):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
query string parameters if required by the server. If a
function is provided, the client will invoke it to obtain
the URL each time a connection or reconnection is
attempted.
:param headers: A dictionary with custom headers to send with the
connection request.
connection request. If a function is provided, the
client will invoke it to obtain the headers dictionary
each time a connection or reconnection is attempted.
:param auth: Authentication data passed to the server with the
connection request, normally a dictionary with one or
more string key/value pairs.
more string key/value pairs. If a function is provided,
the client will invoke it to obtain the authentication
data each time a connection or reconnection is attempted.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
@ -294,8 +301,11 @@ class Client(object):
self._connect_event = self.eio.create_event()
else:
self._connect_event.clear()
real_url = self._get_real_value(self.connection_url)
real_headers = self._get_real_value(self.connection_headers)
try:
self.eio.connect(url, headers=headers, transports=transports,
self.eio.connect(real_url, headers=real_headers,
transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
self._trigger_event(
@ -490,6 +500,13 @@ class Client(object):
"""
return self.eio.sleep(seconds)
def _get_real_value(self, value):
"""Return the actual value, for parameters that can also be given as
callables."""
if not callable(value):
return value
return value()
def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
@ -628,9 +645,10 @@ class Client(object):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
real_auth = self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
self._send_packet(packet.Packet(
packet.CONNECT, data=self.connection_auth, namespace=n))
packet.CONNECT, data=real_auth, namespace=n))
def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""

47
tests/asyncio/test_asyncio_client.py

@ -75,6 +75,30 @@ class TestAsyncClient(unittest.TestCase):
engineio_path='path',
)
def test_connect_functions(self):
async def headers():
return 'headers'
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
_run(
c.connect(
lambda: 'url',
headers=headers,
auth='auth',
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
wait=False,
)
)
c.eio.connect.mock.assert_called_once_with(
'url',
headers='headers',
transports='transports',
engineio_path='path',
)
def test_connect_one_namespace(self):
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
@ -960,6 +984,29 @@ class TestAsyncClient(unittest.TestCase):
== expected_packet.encode()
)
def test_handle_eio_connect_function(self):
c = asyncio_client.AsyncClient()
c.connection_namespaces = ['/', '/foo']
c.connection_auth = lambda: 'auth'
c._send_packet = AsyncMock()
c.eio.sid = 'foo'
assert c.sid is None
_run(c._handle_eio_connect())
assert c.sid == 'foo'
assert c._send_packet.mock.call_count == 2
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/')
assert (
c._send_packet.mock.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/foo')
assert (
c._send_packet.mock.call_args_list[1][0][0].encode()
== expected_packet.encode()
)
def test_handle_eio_message(self):
c = asyncio_client.AsyncClient()
c._handle_connect = AsyncMock()

42
tests/common/test_client.py

@ -173,6 +173,25 @@ class TestClient(unittest.TestCase):
engineio_path='path',
)
def test_connect_functions(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
c.connect(
lambda: 'url',
headers=lambda: 'headers',
auth='auth',
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
wait=False,
)
c.eio.connect.assert_called_once_with(
'url',
headers='headers',
transports='transports',
engineio_path='path',
)
def test_connect_one_namespace(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
@ -1030,6 +1049,29 @@ class TestClient(unittest.TestCase):
== expected_packet.encode()
)
def test_handle_eio_connect_function(self):
c = client.Client()
c.connection_namespaces = ['/', '/foo']
c.connection_auth = lambda: 'auth'
c._send_packet = mock.MagicMock()
c.eio.sid = 'foo'
assert c.sid is None
c._handle_eio_connect()
assert c.sid == 'foo'
assert c._send_packet.call_count == 2
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/')
assert (
c._send_packet.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/foo')
assert (
c._send_packet.call_args_list[1][0][0].encode()
== expected_packet.encode()
)
def test_handle_eio_message(self):
c = client.Client()
c._handle_connect = mock.MagicMock()

Loading…
Cancel
Save