diff --git a/docs/api.rst b/docs/api.rst index 7139e2d..9a3e3df 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,6 +32,12 @@ API Reference :members: :inherited-members: +``ConnectionRefusedError`` class +-------------------------------- + +.. autoclass:: ConnectionRefusedError + :members: + ``WSGIApp`` class ----------------- diff --git a/docs/server.rst b/docs/server.rst index af856bd..d1b154b 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -111,6 +111,15 @@ standard WSGI format containing the request information, including HTTP headers. After inspecting the request, the connect event handler can return ``False`` to reject the connection with the client. +Sometimes it is useful to pass data back to the client being rejected. In that +case instead of returning ``False`` +:class:`socketio.exceptions.ConnectionRefusedError` can be raised, and all of +its argument will be sent to the client with the rejection:: + + @sio.on('connect') + def connect(sid, environ): + raise ConnectionRefusedError('authentication failed') + Emitting Events --------------- diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index d50b3db..ae7ee87 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -371,16 +371,23 @@ class AsyncServer(server.Server): if self.always_connect: await self._send_packet(sid, packet.Packet(packet.CONNECT, namespace=namespace)) - if await self._trigger_event('connect', namespace, sid, - self.environ[sid]) is False: + fail_reason = None + try: + success = await self._trigger_event('connect', namespace, sid, + self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.error_args + success = False + + if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) await self._send_packet(sid, packet.Packet( - packet.DISCONNECT, namespace=namespace)) + packet.DISCONNECT, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace) if not self.always_connect: await self._send_packet(sid, packet.Packet( - packet.ERROR, namespace=namespace)) + packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False diff --git a/socketio/exceptions.py b/socketio/exceptions.py index eb54efa..344aabb 100644 --- a/socketio/exceptions.py +++ b/socketio/exceptions.py @@ -6,5 +6,21 @@ class ConnectionError(SocketIOError): pass +class ConnectionRefusedError(ConnectionError): + """Connection refused exception. + + This exception can be raised from a connect handler when the connection + is not accepted. The positional arguments provided with the exception are + returned with the error packet to the client. + """ + def __init__(self, *args): + if len(args) == 0: + self.error_args = None + elif len(args) == 1: + self.error_args = args[0] + else: + self.error_args = args + + class TimeoutError(SocketIOError): pass diff --git a/socketio/server.py b/socketio/server.py index 7002039..87cc826 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -545,17 +545,23 @@ class Server(object): if self.always_connect: self._send_packet(sid, packet.Packet(packet.CONNECT, namespace=namespace)) - if self._trigger_event('connect', namespace, sid, - self.environ[sid]) is False: + fail_reason = None + try: + success = self._trigger_event('connect', namespace, sid, + self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.error_args + success = False + + if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + self._send_packet(sid, packet.Packet( + packet.DISCONNECT, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace) if not self.always_connect: - self._send_packet(sid, packet.Packet(packet.ERROR, - namespace=namespace)) - + self._send_packet(sid, packet.Packet( + packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index 6c84154..d160b30 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -10,7 +10,7 @@ if six.PY3: else: import mock -from socketio import asyncio_server +from socketio import asyncio_server, exceptions from socketio import asyncio_namespace from socketio import exceptions from socketio import namespace @@ -340,6 +340,34 @@ class TestAsyncServer(unittest.TestCase): s.eio.send.mock.assert_any_call('123', '0/foo', binary=False) s.eio.send.mock.assert_any_call('123', '1/foo', binary=False) + def test_handle_connect_rejected_with_exception(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError('fail_reason')) + s.on('connect', handler) + _run(s._handle_eio_connect('123', 'environ')) + self.assertEqual(s.manager.connect.call_count, 1) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.mock.assert_any_call('123', '4"fail_reason"', binary=False) + + def test_handle_connect_namespace_rejected_with_exception(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError('fail_reason', 1)) + s.on('connect', handler, namespace='/foo') + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo')) + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.mock.assert_any_call('123', '4/foo,["fail_reason",1]', + binary=False) + def test_handle_disconnect(self, eio): eio.return_value.send = AsyncMock() mgr = self._get_mock_manager() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 1f8a70c..7dbf812 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -277,6 +277,32 @@ class TestServer(unittest.TestCase): s.eio.send.assert_any_call('123', '0/foo', binary=False) s.eio.send.assert_any_call('123', '1/foo', binary=False) + def test_handle_connect_rejected_with_exception(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError()) + s.on('connect', handler) + s._handle_eio_connect('123', 'environ') + handler.assert_called_once_with('123', 'environ') + self.assertEqual(s.manager.connect.call_count, 1) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.assert_any_call('123', '4', binary=False) + + def test_handle_connect_namespace_rejected_with_exception(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) + handler = mock.MagicMock( + side_effect=exceptions.ConnectionRefusedError('fail_reason')) + s.on('connect', handler, namespace='/foo') + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo') + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + print(s.eio.send.call_args) + s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False) + def test_handle_disconnect(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr)