Browse Source

Added wait argument to client's connect method (Fixes #634)

pull/657/head
Miguel Grinberg 4 years ago
parent
commit
4da6d74f56
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 42
      socketio/asyncio_client.py
  2. 38
      socketio/client.py
  3. 97
      tests/asyncio/test_asyncio_client.py
  4. 89
      tests/common/test_client.py

42
socketio/asyncio_client.py

@ -63,7 +63,8 @@ class AsyncClient(client.Client):
return True
async def connect(self, url, headers={}, transports=None,
namespaces=None, socketio_path='socket.io'):
namespaces=None, socketio_path='socket.io', wait=True,
wait_timeout=1):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
@ -80,18 +81,26 @@ class AsyncClient(client.Client):
:param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for
most cases.
:param wait: if set to ``True`` (the default) the call only returns
when all the namespaces are connected. If set to
``False``, the call returns as soon as the Engine.IO
transport is connected, and the namespaces will connect
in the background.
:param wait_timeout: How long the client should wait for the
connection. The default is 1 second. This
argument is only considered when ``wait`` is set
to ``True``.
Note: this method is a coroutine.
Note: The connection mechannism occurs in the background and will
complete at some point after this function returns. The connection
will be established when the ``connect`` event is invoked.
Example usage::
sio = socketio.AsyncClient()
sio.connect('http://localhost:5000')
"""
if self.connected:
raise exceptions.ConnectionError('Already connected')
self.connection_url = url
self.connection_headers = headers
self.connection_transports = transports
@ -106,6 +115,11 @@ class AsyncClient(client.Client):
elif isinstance(namespaces, str):
namespaces = [namespaces]
self.connection_namespaces = namespaces
self.namespaces = {}
if self._connect_event is None:
self._connect_event = self.eio.create_event()
else:
self._connect_event.clear()
try:
await self.eio.connect(url, headers=headers,
transports=transports,
@ -115,6 +129,22 @@ class AsyncClient(client.Client):
'connect_error', '/',
exc.args[1] if len(exc.args) > 1 else exc.args[0])
raise exceptions.ConnectionError(exc.args[0]) from None
if wait:
try:
while True:
await asyncio.wait_for(self._connect_event.wait(),
wait_timeout)
self._connect_event.clear()
if set(self.namespaces) == set(self.connection_namespaces):
break
except asyncio.TimeoutError:
pass
if set(self.namespaces) != set(self.connection_namespaces):
await self.disconnect()
raise exceptions.ConnectionError(
'One or more namespaces failed to connect')
self.connected = True
async def wait(self):
@ -301,6 +331,7 @@ class AsyncClient(client.Client):
self.logger.info('Namespace {} is connected'.format(namespace))
self.namespaces[namespace] = (data or {}).get('sid', self.sid)
await self._trigger_event('connect', namespace=namespace)
self._connect_event.set()
async def _handle_disconnect(self, namespace):
if not self.connected:
@ -355,6 +386,7 @@ class AsyncClient(client.Client):
elif not isinstance(data, (tuple, list)):
data = (data,)
await self._trigger_event('connect_error', namespace, *data)
self._connect_event.set()
if namespace in self.namespaces:
del self.namespaces[namespace]
if namespace == '/':

38
socketio/client.py

@ -131,6 +131,7 @@ class Client(object):
self.namespace_handlers = {}
self.callbacks = {}
self._binary_packet = None
self._connect_event = None
self._reconnect_task = None
self._reconnect_abort = None
@ -233,7 +234,8 @@ class Client(object):
namespace_handler
def connect(self, url, headers={}, transports=None,
namespaces=None, socketio_path='socket.io'):
namespaces=None, socketio_path='socket.io', wait=True,
wait_timeout=1):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
@ -250,16 +252,24 @@ class Client(object):
:param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for
most cases.
Note: The connection mechannism occurs in the background and will
complete at some point after this function returns. The connection
will be established when the ``connect`` event is invoked.
:param wait: if set to ``True`` (the default) the call only returns
when all the namespaces are connected. If set to
``False``, the call returns as soon as the Engine.IO
transport is connected, and the namespaces will connect
in the background.
:param wait_timeout: How long the client should wait for the
connection. The default is 1 second. This
argument is only considered when ``wait`` is set
to ``True``.
Example usage::
sio = socketio.Client()
sio.connect('http://localhost:5000')
"""
if self.connected:
raise exceptions.ConnectionError('Already connected')
self.connection_url = url
self.connection_headers = headers
self.connection_transports = transports
@ -274,6 +284,11 @@ class Client(object):
elif isinstance(namespaces, str):
namespaces = [namespaces]
self.connection_namespaces = namespaces
self.namespaces = {}
if self._connect_event is None:
self._connect_event = self.eio.create_event()
else:
self._connect_event.clear()
try:
self.eio.connect(url, headers=headers, transports=transports,
engineio_path=socketio_path)
@ -282,6 +297,17 @@ class Client(object):
'connect_error', '/',
exc.args[1] if len(exc.args) > 1 else exc.args[0])
raise exceptions.ConnectionError(exc.args[0]) from None
if wait:
while self._connect_event.wait(timeout=wait_timeout):
self._connect_event.clear()
if set(self.namespaces) == set(self.connection_namespaces):
break
if set(self.namespaces) != set(self.connection_namespaces):
self.disconnect()
raise exceptions.ConnectionError(
'One or more namespaces failed to connect')
self.connected = True
def wait(self):
@ -483,6 +509,7 @@ class Client(object):
self.logger.info('Namespace {} is connected'.format(namespace))
self.namespaces[namespace] = (data or {}).get('sid', self.sid)
self._trigger_event('connect', namespace=namespace)
self._connect_event.set()
def _handle_disconnect(self, namespace):
if not self.connected:
@ -534,6 +561,7 @@ class Client(object):
elif not isinstance(data, (tuple, list)):
data = (data,)
self._trigger_event('connect_error', namespace, *data)
self._connect_event.set()
if namespace in self.namespaces:
del self.namespaces[namespace]
if namespace == '/':

97
tests/asyncio/test_asyncio_client.py

@ -58,6 +58,7 @@ class TestAsyncClient(unittest.TestCase):
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
wait=False,
)
)
assert c.connection_url == 'url'
@ -82,6 +83,7 @@ class TestAsyncClient(unittest.TestCase):
transports='transports',
namespaces='/foo',
socketio_path='path',
wait=False,
)
)
assert c.connection_url == 'url'
@ -107,6 +109,7 @@ class TestAsyncClient(unittest.TestCase):
headers='headers',
transports='transports',
socketio_path='path',
wait=False,
)
)
assert c.connection_url == 'url'
@ -131,6 +134,7 @@ class TestAsyncClient(unittest.TestCase):
headers='headers',
transports='transports',
socketio_path='path',
wait=False,
)
)
assert c.connection_url == 'url'
@ -159,9 +163,86 @@ class TestAsyncClient(unittest.TestCase):
headers='headers',
transports='transports',
socketio_path='path',
wait=False,
)
)
def test_connect_twice(self):
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
_run(
c.connect(
'url',
wait=False,
)
)
with pytest.raises(exceptions.ConnectionError):
_run(
c.connect(
'url',
wait=False,
)
)
def test_connect_wait_single_namespace(self):
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
c._connect_event = mock.MagicMock()
async def mock_connect():
c.namespaces = {'/': '123'}
return True
c._connect_event.wait = mock_connect
_run(
c.connect(
'url',
wait=True,
wait_timeout=0.01,
)
)
assert c.connected is True
def test_connect_wait_two_namespaces(self):
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
c._connect_event = mock.MagicMock()
async def mock_connect():
if c.namespaces == {}:
c.namespaces = {'/bar': '123'}
return True
elif c.namespaces == {'/bar': '123'}:
c.namespaces = {'/bar': '123', '/foo': '456'}
return True
return False
c._connect_event.wait = mock_connect
_run(
c.connect(
'url',
namespaces=['/foo', '/bar'],
wait=True,
wait_timeout=0.01,
)
)
assert c.connected is True
assert c.namespaces == {'/bar': '123', '/foo': '456'}
def test_connect_timeout(self):
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
c.disconnect = AsyncMock()
with pytest.raises(exceptions.ConnectionError):
_run(
c.connect(
'url',
wait=True,
wait_timeout=0.01,
)
)
c.disconnect.mock.assert_called_once_with()
def test_wait_no_reconnect(self):
c = asyncio_client.AsyncClient()
c.eio.wait = AsyncMock()
@ -486,29 +567,35 @@ class TestAsyncClient(unittest.TestCase):
def test_handle_connect(self):
c = asyncio_client.AsyncClient()
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
c._send_packet = AsyncMock()
_run(c._handle_connect('/', {'sid': '123'}))
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with('connect', namespace='/')
c._send_packet.mock.assert_not_called()
def test_handle_connect_with_namespaces(self):
c = asyncio_client.AsyncClient()
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
c._send_packet = AsyncMock()
_run(c._handle_connect('/', {'sid': '3'}))
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with('connect', namespace='/')
assert c.namespaces == {'/': '3', '/foo': '1', '/bar': '2'}
def test_handle_connect_namespace(self):
c = asyncio_client.AsyncClient()
c.namespaces = {'/foo': '1'}
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
c._send_packet = AsyncMock()
_run(c._handle_connect('/foo', {'sid': '123'}))
_run(c._handle_connect('/bar', {'sid': '2'}))
assert c._trigger_event.mock.call_count == 1
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with(
'connect', namespace='/bar')
assert c.namespaces == {'/foo': '1', '/bar': '2'}
@ -658,11 +745,13 @@ class TestAsyncClient(unittest.TestCase):
def test_handle_error(self):
c = asyncio_client.AsyncClient()
c.connected = True
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
c.namespaces = {'/foo': '1', '/bar': '2'}
_run(c._handle_error('/', 'error'))
assert c.namespaces == {}
assert not c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with(
'connect_error', '/', 'error'
)
@ -670,21 +759,25 @@ class TestAsyncClient(unittest.TestCase):
def test_handle_error_with_no_arguments(self):
c = asyncio_client.AsyncClient()
c.connected = True
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
c.namespaces = {'/foo': '1', '/bar': '2'}
_run(c._handle_error('/', None))
assert c.namespaces == {}
assert not c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with('connect_error', '/')
def test_handle_error_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
_run(c._handle_error('/bar', ['error', 'message']))
assert c.namespaces == {'/foo': '1'}
assert c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with(
'connect_error', '/bar', 'error', 'message'
)
@ -693,19 +786,23 @@ class TestAsyncClient(unittest.TestCase):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = AsyncMock()
_run(c._handle_error('/bar', None))
assert c.namespaces == {'/foo': '1'}
assert c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.mock.assert_called_once_with('connect_error', '/bar')
def test_handle_error_unknown_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
_run(c._handle_error('/baz', 'error'))
assert c.namespaces == {'/foo': '1', '/bar': '2'}
assert c.connected
c._connect_event.set.assert_called_once_with()
def test_trigger_event(self):
c = asyncio_client.AsyncClient()

89
tests/common/test_client.py

@ -157,6 +157,7 @@ class TestClient(unittest.TestCase):
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
wait=False,
)
assert c.connection_url == 'url'
assert c.connection_headers == 'headers'
@ -179,6 +180,7 @@ class TestClient(unittest.TestCase):
transports='transports',
namespaces='/foo',
socketio_path='path',
wait=False,
)
assert c.connection_url == 'url'
assert c.connection_headers == 'headers'
@ -202,6 +204,7 @@ class TestClient(unittest.TestCase):
headers='headers',
transports='transports',
socketio_path='path',
wait=False,
)
assert c.connection_url == 'url'
assert c.connection_headers == 'headers'
@ -224,6 +227,7 @@ class TestClient(unittest.TestCase):
headers='headers',
transports='transports',
socketio_path='path',
wait=False,
)
assert c.connection_url == 'url'
assert c.connection_headers == 'headers'
@ -250,8 +254,77 @@ class TestClient(unittest.TestCase):
headers='headers',
transports='transports',
socketio_path='path',
wait=False,
)
def test_connect_twice(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
c.connect(
'url',
wait=False,
)
with pytest.raises(exceptions.ConnectionError):
c.connect(
'url',
wait=False,
)
def test_connect_wait_single_namespace(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
c._connect_event = mock.MagicMock()
def mock_connect(timeout):
assert timeout == 0.01
c.namespaces = {'/': '123'}
return True
c._connect_event.wait = mock_connect
c.connect(
'url',
wait=True,
wait_timeout=0.01,
)
assert c.connected is True
def test_connect_wait_two_namespaces(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
c._connect_event = mock.MagicMock()
def mock_connect(timeout):
assert timeout == 0.01
if c.namespaces == {}:
c.namespaces = {'/bar': '123'}
return True
elif c.namespaces == {'/bar': '123'}:
c.namespaces = {'/bar': '123', '/foo': '456'}
return True
return False
c._connect_event.wait = mock_connect
c.connect(
'url',
namespaces=['/foo', '/bar'],
wait=True,
wait_timeout=0.01,
)
assert c.connected is True
assert c.namespaces == {'/bar': '123', '/foo': '456'}
def test_connect_timeout(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
c.disconnect = mock.MagicMock()
with pytest.raises(exceptions.ConnectionError):
c.connect(
'url',
wait=True,
wait_timeout=0.01,
)
c.disconnect.assert_called_once_with()
def test_wait_no_reconnect(self):
c = client.Client()
c.eio.wait = mock.MagicMock()
@ -602,30 +675,36 @@ class TestClient(unittest.TestCase):
def test_handle_connect(self):
c = client.Client()
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._send_packet = mock.MagicMock()
c._handle_connect('/', {'sid': '123'})
assert c.namespaces == {'/': '123'}
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with('connect', namespace='/')
c._send_packet.assert_not_called()
def test_handle_connect_with_namespaces(self):
c = client.Client()
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._send_packet = mock.MagicMock()
c._handle_connect('/', {'sid': '3'})
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with('connect', namespace='/')
assert c.namespaces == {'/': '3', '/foo': '1', '/bar': '2'}
def test_handle_connect_namespace(self):
c = client.Client()
c.namespaces = {'/foo': '1'}
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._send_packet = mock.MagicMock()
c._handle_connect('/foo', {'sid': '123'})
c._handle_connect('/bar', {'sid': '2'})
assert c._trigger_event.call_count == 1
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with('connect', namespace='/bar')
assert c.namespaces == {'/foo': '1', '/bar': '2'}
@ -762,30 +841,36 @@ class TestClient(unittest.TestCase):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._handle_error('/', 'error')
assert c.namespaces == {}
assert not c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with('connect_error', '/', 'error')
def test_handle_error_with_no_arguments(self):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._handle_error('/', None)
assert c.namespaces == {}
assert not c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with('connect_error', '/')
def test_handle_error_namespace(self):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._handle_error('/bar', ['error', 'message'])
assert c.namespaces == {'/foo': '1'}
assert c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with(
'connect_error', '/bar', 'error', 'message'
)
@ -794,19 +879,23 @@ class TestClient(unittest.TestCase):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c._handle_error('/bar', None)
assert c.namespaces == {'/foo': '1'}
assert c.connected
c._connect_event.set.assert_called_once_with()
c._trigger_event.assert_called_once_with('connect_error', '/bar')
def test_handle_error_unknown_namespace(self):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._connect_event = mock.MagicMock()
c._handle_error('/baz', 'error')
assert c.namespaces == {'/foo': '1', '/bar': '2'}
assert c.connected
c._connect_event.set.assert_called_once_with()
def test_trigger_event(self):
c = client.Client()

Loading…
Cancel
Save