From 8d08096dc442e817b5e4ce53321ccf196daafcd1 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 22 May 2020 19:09:57 +0100 Subject: [PATCH] Improved handling of rejected connections (#391 #487 #447) --- socketio/asyncio_client.py | 3 ++ socketio/asyncio_server.py | 7 ++-- socketio/client.py | 3 ++ socketio/server.py | 7 ++-- tests/asyncio/test_asyncio_server.py | 55 ++++++++++++++++++++++------ tests/common/test_server.py | 55 +++++++++++++++++++++++----- 6 files changed, 102 insertions(+), 28 deletions(-) diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index 9f8d47a..b848ecd 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -110,6 +110,9 @@ class AsyncClient(client.Client): transports=transports, engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: + await self._trigger_event( + 'connect_error', '/', + exc.args[1] if len(exc.args) > 1 else exc.args[0]) six.raise_from(exceptions.ConnectionError(exc.args[0]), None) self.connected = True diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 09cd5d2..20da548 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -426,12 +426,13 @@ class AsyncServer(server.Server): self.manager.pre_disconnect(sid, namespace) await self._send_packet(sid, packet.Packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) - self.manager.disconnect(sid, namespace) - if not self.always_connect: + elif namespace != '/': await self._send_packet(sid, packet.Packet( packet.ERROR, data=fail_reason, namespace=namespace)) - if sid in self.environ: # pragma: no cover + self.manager.disconnect(sid, namespace) + if namespace == '/' and sid in self.environ: # pragma: no cover del self.environ[sid] + return fail_reason or False elif not self.always_connect: await self._send_packet(sid, packet.Packet(packet.CONNECT, namespace=namespace)) diff --git a/socketio/client.py b/socketio/client.py index d836932..fb3f946 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -276,6 +276,9 @@ class Client(object): self.eio.connect(url, headers=headers, transports=transports, engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: + self._trigger_event( + 'connect_error', '/', + exc.args[1] if len(exc.args) > 1 else exc.args[0]) six.raise_from(exceptions.ConnectionError(exc.args[0]), None) self.connected = True diff --git a/socketio/server.py b/socketio/server.py index 8c31707..01dcf94 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -635,12 +635,13 @@ class Server(object): self.manager.pre_disconnect(sid, namespace) self._send_packet(sid, packet.Packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) - self.manager.disconnect(sid, namespace) - if not self.always_connect: + elif namespace != '/': self._send_packet(sid, packet.Packet( packet.ERROR, data=fail_reason, namespace=namespace)) - if sid in self.environ: # pragma: no cover + self.manager.disconnect(sid, namespace) + if namespace == '/' and sid in self.environ: # pragma: no cover del self.environ[sid] + return fail_reason or False elif not self.always_connect: self._send_packet(sid, packet.Packet(packet.CONNECT, namespace=namespace)) diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index 201aebe..308a8a5 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -313,17 +313,16 @@ class TestAsyncServer(unittest.TestCase): s.eio.send.mock.assert_any_call('123', '0/foo', binary=False) def test_handle_connect_rejected(self, eio): - eio.return_value.send = AsyncMock() mgr = self._get_mock_manager() s = asyncio_server.AsyncServer(client_manager=mgr) handler = mock.MagicMock(return_value=False) s.on('connect', handler) - _run(s._handle_eio_connect('123', 'environ')) + ret = _run(s._handle_eio_connect('123', 'environ')) + self.assertFalse(ret) handler.assert_called_once_with('123', 'environ') self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.mock.assert_called_once_with('123', '4', binary=False) def test_handle_connect_namespace_rejected(self, eio): eio.return_value.send = AsyncMock() @@ -331,11 +330,12 @@ class TestAsyncServer(unittest.TestCase): s = asyncio_server.AsyncServer(client_manager=mgr) handler = mock.MagicMock(return_value=False) s.on('connect', handler, namespace='/foo') - _run(s._handle_eio_connect('123', 'environ')) + ret = _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0/foo')) + self.assertIsNone(ret) self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) - self.assertEqual(s.environ, {}) + self.assertEqual(s.environ, {'123': 'environ'}) s.eio.send.mock.assert_any_call('123', '4/foo', binary=False) def test_handle_connect_rejected_always_connect(self, eio): @@ -345,7 +345,8 @@ class TestAsyncServer(unittest.TestCase): always_connect=True) handler = mock.MagicMock(return_value=False) s.on('connect', handler) - _run(s._handle_eio_connect('123', 'environ')) + ret = _run(s._handle_eio_connect('123', 'environ')) + self.assertFalse(ret) handler.assert_called_once_with('123', 'environ') self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) @@ -360,11 +361,12 @@ class TestAsyncServer(unittest.TestCase): always_connect=True) handler = mock.MagicMock(return_value=False) s.on('connect', handler, namespace='/foo') - _run(s._handle_eio_connect('123', 'environ')) + ret = _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0/foo')) + self.assertFalse(ret) self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) - self.assertEqual(s.environ, {}) + self.assertEqual(s.environ, {'123': 'environ'}) s.eio.send.mock.assert_any_call('123', '0/foo', binary=False) s.eio.send.mock.assert_any_call('123', '1/foo', binary=False) @@ -375,11 +377,24 @@ class TestAsyncServer(unittest.TestCase): handler = mock.MagicMock( side_effect=exceptions.ConnectionRefusedError('fail_reason')) s.on('connect', handler) - _run(s._handle_eio_connect('123', 'environ')) + ret = _run(s._handle_eio_connect('123', 'environ')) + self.assertEqual(ret, 'fail_reason') + self.assertEqual(s.manager.connect.call_count, 1) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + + def test_handle_connect_rejected_with_empty_exception(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError()) + s.on('connect', handler) + ret = _run(s._handle_eio_connect('123', 'environ')) + self.assertFalse(ret) self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.mock.assert_any_call('123', '4"fail_reason"', binary=False) def test_handle_connect_namespace_rejected_with_exception(self, eio): eio.return_value.send = AsyncMock() @@ -388,14 +403,30 @@ class TestAsyncServer(unittest.TestCase): handler = mock.MagicMock( side_effect=exceptions.ConnectionRefusedError('fail_reason', 1)) s.on('connect', handler, namespace='/foo') - _run(s._handle_eio_connect('123', 'environ')) + ret = _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0/foo')) + self.assertIsNone(ret) self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) - self.assertEqual(s.environ, {}) + self.assertEqual(s.environ, {'123': 'environ'}) s.eio.send.mock.assert_any_call('123', '4/foo,["fail_reason",1]', binary=False) + def test_handle_connect_namespace_rejected_with_empty_exception(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError()) + s.on('connect', handler, namespace='/foo') + ret = _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo')) + self.assertIsNone(ret) + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {'123': 'environ'}) + s.eio.send.mock.assert_any_call('123', '4/foo', binary=False) + def test_handle_disconnect(self, eio): eio.return_value.send = AsyncMock() mgr = self._get_mock_manager() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 44c9d89..2b14038 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -285,22 +285,24 @@ class TestServer(unittest.TestCase): s = server.Server(client_manager=mgr) handler = mock.MagicMock(return_value=False) s.on('connect', handler) - s._handle_eio_connect('123', 'environ') + ret = s._handle_eio_connect('123', 'environ') + self.assertFalse(ret) handler.assert_called_once_with('123', 'environ') self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.assert_called_once_with('123', '4', binary=False) def test_handle_connect_namespace_rejected(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr) handler = mock.MagicMock(return_value=False) s.on('connect', handler, namespace='/foo') - s._handle_eio_connect('123', 'environ') + ret = s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo') + self.assertIsNone(ret) self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {'123': 'environ'}) s.eio.send.assert_any_call('123', '4/foo', binary=False) def test_handle_connect_rejected_always_connect(self, eio): @@ -308,7 +310,8 @@ class TestServer(unittest.TestCase): s = server.Server(client_manager=mgr, always_connect=True) handler = mock.MagicMock(return_value=False) s.on('connect', handler) - s._handle_eio_connect('123', 'environ') + ret = s._handle_eio_connect('123', 'environ') + self.assertFalse(ret) handler.assert_called_once_with('123', 'environ') self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) @@ -321,37 +324,69 @@ class TestServer(unittest.TestCase): s = server.Server(client_manager=mgr, always_connect=True) handler = mock.MagicMock(return_value=False) s.on('connect', handler, namespace='/foo') - s._handle_eio_connect('123', 'environ') + ret = s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo') + self.assertIsNone(ret) self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {'123': 'environ'}) s.eio.send.assert_any_call('123', '0/foo', binary=False) s.eio.send.assert_any_call('123', '1/foo', binary=False) def test_handle_connect_rejected_with_exception(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError('fail_reason')) + s.on('connect', handler) + ret = s._handle_eio_connect('123', 'environ') + self.assertEqual(ret, 'fail_reason') + handler.assert_called_once_with('123', 'environ') + self.assertEqual(s.manager.connect.call_count, 1) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + + def test_handle_connect_rejected_with_empty_exception(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr) handler = mock.MagicMock( side_effect=exceptions.ConnectionRefusedError()) s.on('connect', handler) - s._handle_eio_connect('123', 'environ') + ret = s._handle_eio_connect('123', 'environ') + self.assertFalse(ret) handler.assert_called_once_with('123', 'environ') self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.assert_any_call('123', '4', binary=False) def test_handle_connect_namespace_rejected_with_exception(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr) handler = mock.MagicMock( - side_effect=exceptions.ConnectionRefusedError(u'fail_reason')) + side_effect=exceptions.ConnectionRefusedError(u'fail_reason', 1)) s.on('connect', handler, namespace='/foo') - s._handle_eio_connect('123', 'environ') + ret = s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo') + self.assertIsNone(ret) + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {'123': 'environ'}) + s.eio.send.assert_any_call('123', '4/foo,["fail_reason",1]', + binary=False) + + def test_handle_connect_namespace_rejected_with_empty_exception(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError()) + s.on('connect', handler, namespace='/foo') + ret = s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo') + self.assertIsNone(ret) self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) - s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False) + self.assertEqual(s.environ, {'123': 'environ'}) + s.eio.send.assert_any_call('123', '4/foo', binary=False) def test_handle_disconnect(self, eio): mgr = mock.MagicMock()