Browse Source

Disconnect Engine.IO connection when server disconnects a client (https://github.com/miguelgrinberg/Flask-SocketIO/issues/1017)

pull/348/head
Miguel Grinberg 6 years ago
parent
commit
516a2958f4
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 22
      socketio/asyncio_client.py
  2. 2
      socketio/asyncio_server.py
  3. 24
      socketio/client.py
  4. 2
      socketio/server.py
  5. 52
      tests/asyncio/test_asyncio_client.py
  6. 6
      tests/asyncio/test_asyncio_server.py
  7. 49
      tests/common/test_client.py

22
socketio/asyncio_client.py

@ -102,6 +102,7 @@ class AsyncClient(client.Client):
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
self.connected = True
async def wait(self):
"""Wait until the connection with the server ends.
@ -232,6 +233,7 @@ class AsyncClient(client.Client):
namespace=n))
await self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.connected = False
await self.eio.disconnect(abort=True)
def start_background_task(self, target, *args, **kwargs):
@ -286,10 +288,18 @@ class AsyncClient(client.Client):
self.namespaces.append(namespace)
async def _handle_disconnect(self, namespace):
if not self.connected:
return
namespace = namespace or '/'
if namespace == '/':
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
self.namespaces = []
await self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.connected = False
async def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
@ -335,6 +345,9 @@ class AsyncClient(client.Client):
namespace))
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.connected = False
async def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
@ -431,9 +444,12 @@ class AsyncClient(client.Client):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
self._reconnect_abort.set()
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
if self.connected:
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
self.namespaces = []
self.connected = False
self.callbacks = {}
self._binary_packet = None
self.sid = None

2
socketio/asyncio_server.py

@ -312,6 +312,8 @@ class AsyncServer(server.Server):
namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
await self.eio.disconnect(sid)
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.

24
socketio/client.py

@ -112,6 +112,7 @@ class Client(object):
self.socketio_path = None
self.sid = None
self.connected = False
self.namespaces = []
self.handlers = {}
self.namespace_handlers = {}
@ -261,6 +262,7 @@ class Client(object):
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
self.connected = True
def wait(self):
"""Wait until the connection with the server ends.
@ -377,6 +379,7 @@ class Client(object):
self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.connected = False
self.eio.disconnect(abort=True)
def transport(self):
@ -445,10 +448,18 @@ class Client(object):
self.namespaces.append(namespace)
def _handle_disconnect(self, namespace):
if not self.connected:
return
namespace = namespace or '/'
if namespace == '/':
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self.namespaces = []
self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.connected = False
def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
@ -490,6 +501,9 @@ class Client(object):
namespace))
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.connected = False
def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
@ -516,7 +530,6 @@ class Client(object):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
print('***', self._reconnect_abort.wait)
if self._reconnect_abort.wait(delay):
self.logger.info('Reconnect task aborted')
break
@ -576,9 +589,12 @@ class Client(object):
def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self._trigger_event('disconnect', namespace='/')
if self.connected:
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self._trigger_event('disconnect', namespace='/')
self.namespaces = []
self.connected = False
self.callbacks = {}
self._binary_packet = None
self.sid = None

2
socketio/server.py

@ -504,6 +504,8 @@ class Server(object):
namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
self.eio.disconnect(sid)
def transport(self, sid):
"""Return the name of the transport used by the client.

52
tests/asyncio/test_asyncio_client.py

@ -411,28 +411,51 @@ class TestAsyncClient(unittest.TestCase):
def test_handle_disconnect(self):
c = asyncio_client.AsyncClient()
c.connected = True
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/'))
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/')
self.assertFalse(c.connected)
_run(c._handle_disconnect('/'))
self.assertEqual(c._trigger_event.mock.call_count, 1)
def test_handle_disconnect_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/foo'))
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/foo')
self.assertEqual(c.namespaces, ['/bar'])
self.assertTrue(c.connected)
def test_handle_disconnect_unknown_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/baz'))
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/baz')
self.assertEqual(c.namespaces, ['/foo', '/bar'])
self.assertTrue(c.connected)
def test_handle_disconnect_all_namespaces(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/'))
c._trigger_event.mock.assert_any_call(
'disconnect', namespace='/')
c._trigger_event.mock.assert_any_call(
'disconnect', namespace='/foo')
c._trigger_event.mock.assert_any_call(
'disconnect', namespace='/bar')
self.assertEqual(c.namespaces, [])
self.assertFalse(c.connected)
def test_handle_event(self):
c = asyncio_client.AsyncClient()
@ -519,15 +542,27 @@ class TestAsyncClient(unittest.TestCase):
def test_handle_error(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/')
self.assertEqual(c.namespaces, [])
self.assertFalse(c.connected)
def test_handle_error_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/bar')
self.assertEqual(c.namespaces, ['/foo'])
self.assertTrue(c.connected)
def test_handle_error_unknown_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/baz')
self.assertEqual(c.namespaces, ['/foo', '/bar'])
self.assertTrue(c.connected)
def test_trigger_event(self):
c = asyncio_client.AsyncClient()
@ -556,6 +591,19 @@ class TestAsyncClient(unittest.TestCase):
_run(c._trigger_event('foo', '/', 1, '2'))
self.assertEqual(result, [1, '2'])
def test_trigger_event_unknown_namespace(self):
c = asyncio_client.AsyncClient()
result = []
class MyNamespace(asyncio_namespace.AsyncClientNamespace):
def on_foo(self, a, b):
result.append(a)
result.append(b)
c.register_namespace(MyNamespace('/'))
_run(c._trigger_event('foo', '/bar', 1, '2'))
self.assertEqual(result, [])
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
@ -663,6 +711,7 @@ class TestAsyncClient(unittest.TestCase):
def test_eio_disconnect(self):
c = asyncio_client.AsyncClient()
c.connected = True
c._trigger_event = AsyncMock()
c.sid = 'foo'
c.eio.state = 'connected'
@ -670,9 +719,11 @@ class TestAsyncClient(unittest.TestCase):
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/')
self.assertIsNone(c.sid)
self.assertFalse(c.connected)
def test_eio_disconnect_namespaces(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
c.sid = 'foo'
@ -682,6 +733,7 @@ class TestAsyncClient(unittest.TestCase):
c._trigger_event.mock.assert_any_call('disconnect', namespace='/bar')
c._trigger_event.mock.assert_any_call('disconnect', namespace='/')
self.assertIsNone(c.sid)
self.assertFalse(c.connected)
def test_eio_disconnect_reconnect(self):
c = asyncio_client.AsyncClient(reconnection=True)

6
tests/asyncio/test_asyncio_server.py

@ -596,27 +596,33 @@ class TestAsyncServer(unittest.TestCase):
def test_disconnect(self, eio):
eio.return_value.send = AsyncMock()
eio.return_value.disconnect = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s.disconnect('123'))
s.eio.send.mock.assert_any_call('123', '1', binary=False)
s.eio.disconnect.mock.assert_called_once_with('123')
def test_disconnect_namespace(self, eio):
eio.return_value.send = AsyncMock()
eio.return_value.disconnect = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0/foo'))
_run(s.disconnect('123', namespace='/foo'))
s.eio.send.mock.assert_any_call('123', '1/foo', binary=False)
s.eio.disconnect.mock.assert_not_called()
def test_disconnect_twice(self, eio):
eio.return_value.send = AsyncMock()
eio.return_value.disconnect = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s.disconnect('123'))
calls = s.eio.send.mock.call_count
_run(s.disconnect('123'))
self.assertEqual(calls, s.eio.send.mock.call_count)
self.assertEqual(s.eio.disconnect.mock.call_count, 1)
def test_disconnect_twice_namespace(self, eio):
eio.return_value.send = AsyncMock()

49
tests/common/test_client.py

@ -536,27 +536,47 @@ class TestClient(unittest.TestCase):
def test_handle_disconnect(self):
c = client.Client()
c.connected = True
c._trigger_event = mock.MagicMock()
c._handle_disconnect('/')
c._trigger_event.assert_called_once_with('disconnect', namespace='/')
self.assertFalse(c.connected)
c._handle_disconnect('/')
self.assertEqual(c._trigger_event.call_count, 1)
def test_handle_disconnect_namespace(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = mock.MagicMock()
c._handle_disconnect('/foo')
c._trigger_event.assert_called_once_with('disconnect',
namespace='/foo')
self.assertEqual(c.namespaces, ['/bar'])
self.assertTrue(c.connected)
def test_handle_disconnect_unknown_namespace(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = mock.MagicMock()
c._handle_disconnect('/baz')
c._trigger_event.assert_called_once_with('disconnect',
namespace='/baz')
self.assertEqual(c.namespaces, ['/foo', '/bar'])
self.assertTrue(c.connected)
def test_handle_disconnect_all_namespaces(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = mock.MagicMock()
c._handle_disconnect('/')
c._trigger_event.assert_any_call('disconnect', namespace='/')
c._trigger_event.assert_any_call('disconnect', namespace='/foo')
c._trigger_event.assert_any_call('disconnect', namespace='/bar')
self.assertEqual(c.namespaces, [])
self.assertFalse(c.connected)
def test_handle_event(self):
c = client.Client()
@ -630,15 +650,27 @@ class TestClient(unittest.TestCase):
def test_handle_error(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/')
self.assertEqual(c.namespaces, [])
self.assertFalse(c.connected)
def test_handle_error_namespace(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/bar')
self.assertEqual(c.namespaces, ['/foo'])
self.assertTrue(c.connected)
def test_handle_error_unknown_namespace(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/baz')
self.assertEqual(c.namespaces, ['/foo', '/bar'])
self.assertTrue(c.connected)
def test_trigger_event(self):
c = client.Client()
@ -667,6 +699,19 @@ class TestClient(unittest.TestCase):
c._trigger_event('foo', '/', 1, '2')
self.assertEqual(result, [1, '2'])
def test_trigger_event_unknown_namespace(self):
c = client.Client()
result = []
class MyNamespace(namespace.ClientNamespace):
def on_foo(self, a, b):
result.append(a)
result.append(b)
c.register_namespace(MyNamespace('/'))
c._trigger_event('foo', '/bar', 1, '2')
self.assertEqual(result, [])
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect(self, random):
c = client.Client()
@ -774,6 +819,7 @@ class TestClient(unittest.TestCase):
def test_eio_disconnect(self):
c = client.Client()
c.connected = True
c._trigger_event = mock.MagicMock()
c.start_background_task = mock.MagicMock()
c.sid = 'foo'
@ -781,9 +827,11 @@ class TestClient(unittest.TestCase):
c._handle_eio_disconnect()
c._trigger_event.assert_called_once_with('disconnect', namespace='/')
self.assertIsNone(c.sid)
self.assertFalse(c.connected)
def test_eio_disconnect_namespaces(self):
c = client.Client()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = mock.MagicMock()
c.start_background_task = mock.MagicMock()
@ -794,6 +842,7 @@ class TestClient(unittest.TestCase):
c._trigger_event.assert_any_call('disconnect', namespace='/bar')
c._trigger_event.assert_any_call('disconnect', namespace='/')
self.assertIsNone(c.sid)
self.assertFalse(c.connected)
def test_eio_disconnect_reconnect(self):
c = client.Client(reconnection=True)

Loading…
Cancel
Save