From f3b5210289fd87f2780ecdcc4b89e3f9b463370b Mon Sep 17 00:00:00 2001
From: Andrey Rusanov <rusanov.andrey.mail@gmail.com>
Date: Sun, 3 Feb 2019 09:15:17 +0200
Subject: [PATCH] Add tests and docs

---
 docs/api.rst                         |  6 ++++++
 docs/server.rst                      | 12 +++++++++++
 tests/asyncio/test_asyncio_server.py | 32 +++++++++++++++++++++++++++-
 tests/common/test_server.py          | 27 +++++++++++++++++++++++
 4 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/docs/api.rst b/docs/api.rst
index 7139e2d..3fe850f 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -114,3 +114,9 @@ API Reference
 
 .. autoclass:: AsyncRedisManager
    :members:
+
+``ConnectionRefusedError`` class
+--------------------------------
+
+.. autoclass:: ConnectionRefusedError
+   :members:
diff --git a/docs/server.rst b/docs/server.rst
index af856bd..22513be 100644
--- a/docs/server.rst
+++ b/docs/server.rst
@@ -111,6 +111,18 @@ standard WSGI format containing the request information, including HTTP
 headers. After inspecting the request, the connect event handler can return
 ``False`` to reject the connection with the client.
 
+If any additional data has to be passed on connection reject, than instead of
+returning ``False`` :class:`socketio.exceptions.ConnectionRefusedError` could
+be raised:
+
+    @sio.on('connect')
+    def connect(sid, environ):
+        message = 'Incorrect user data'
+        raise ConnectionRefusedError(message)
+
+In this case message will be returned directly to the client with rejected
+connection.
+
 Emitting Events
 ---------------
 
diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py
index 4761050..4a8f00c 100644
--- a/tests/asyncio/test_asyncio_server.py
+++ b/tests/asyncio/test_asyncio_server.py
@@ -10,7 +10,7 @@ if six.PY3:
 else:
     import mock
 
-from socketio import asyncio_server
+from socketio import asyncio_server, exceptions
 from socketio import asyncio_namespace
 from socketio import packet
 from socketio import namespace
@@ -276,6 +276,36 @@ class TestAsyncServer(unittest.TestCase):
         self.assertEqual(s.environ, {})
         s.eio.send.mock.assert_any_call('123', '4/foo', binary=False)
 
+    def test_handle_connect_namespace_rejected_with_exception(self, eio):
+        eio.return_value.send = AsyncMock()
+        mgr = self._get_mock_manager()
+        s = asyncio_server.AsyncServer(client_manager=mgr)
+        handler = mock.MagicMock(side_effect=exceptions.ConnectionRefusedError('fail_reason'))
+        s.on('connect', handler, namespace='/foo')
+        _run(s._handle_eio_connect('123', 'environ'))
+        _run(s._handle_eio_message('123', '0/foo'))
+        self.assertEqual(s.manager.connect.call_count, 2)
+        self.assertEqual(s.manager.disconnect.call_count, 1)
+        self.assertEqual(s.environ, {})
+        s.eio.send.mock.assert_any_call('123', '4/foo,"fail_reason"', binary=False)
+
+    def test_handle_connect_namespace_rejected_with_custom_exception(self, eio):
+        class CustomizedConnRefused(exceptions.ConnectionRefusedError):
+            def get_info(self):
+                return 'customized: {}'.format(self._info)
+
+        eio.return_value.send = AsyncMock()
+        mgr = self._get_mock_manager()
+        s = asyncio_server.AsyncServer(client_manager=mgr)
+        handler = mock.MagicMock(side_effect=CustomizedConnRefused('fail_reason'))
+        s.on('connect', handler, namespace='/foo')
+        _run(s._handle_eio_connect('123', 'environ'))
+        _run(s._handle_eio_message('123', '0/foo'))
+        self.assertEqual(s.manager.connect.call_count, 2)
+        self.assertEqual(s.manager.disconnect.call_count, 1)
+        self.assertEqual(s.environ, {})
+        s.eio.send.mock.assert_any_call('123', '4/foo,"customized: fail_reason"', binary=False)
+
     def test_handle_disconnect(self, eio):
         eio.return_value.send = AsyncMock()
         mgr = self._get_mock_manager()
diff --git a/tests/common/test_server.py b/tests/common/test_server.py
index 5e16ec5..61df07d 100644
--- a/tests/common/test_server.py
+++ b/tests/common/test_server.py
@@ -8,6 +8,7 @@ if six.PY3:
 else:
     import mock
 
+from socketio import exceptions
 from socketio import packet
 from socketio import server
 from socketio import namespace
@@ -218,6 +219,32 @@ class TestServer(unittest.TestCase):
         self.assertEqual(s.manager.disconnect.call_count, 1)
         s.eio.send.assert_any_call('123', '4/foo', binary=False)
 
+    def test_handle_connect_namespace_rejected_with_exception(self, eio):
+        mgr = mock.MagicMock()
+        s = server.Server(client_manager=mgr)
+        handler = mock.MagicMock(side_effect=exceptions.ConnectionRefusedError('fail_reason'))
+        s.on('connect', handler, namespace='/foo')
+        s._handle_eio_connect('123', 'environ')
+        s._handle_eio_message('123', '0/foo')
+        self.assertEqual(s.manager.connect.call_count, 2)
+        self.assertEqual(s.manager.disconnect.call_count, 1)
+        s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False)
+
+    def test_handle_connect_namespace_rejected_with_custom_exception(self, eio):
+        class CustomizedConnRefused(exceptions.ConnectionRefusedError):
+            def get_info(self):
+                return 'customized: {}'.format(self._info)
+
+        mgr = mock.MagicMock()
+        s = server.Server(client_manager=mgr)
+        handler = mock.MagicMock(side_effect=CustomizedConnRefused('fail_reason'))
+        s.on('connect', handler, namespace='/foo')
+        s._handle_eio_connect('123', 'environ')
+        s._handle_eio_message('123', '0/foo')
+        self.assertEqual(s.manager.connect.call_count, 2)
+        self.assertEqual(s.manager.disconnect.call_count, 1)
+        s.eio.send.assert_any_call('123', '4/foo,"customized: fail_reason"', binary=False)
+
     def test_handle_disconnect(self, eio):
         mgr = mock.MagicMock()
         s = server.Server(client_manager=mgr)