From 38edd2c93950d85d97258c104af24792b01dc1e3 Mon Sep 17 00:00:00 2001 From: Andrey Rusanov Date: Mon, 28 Jan 2019 13:33:38 +0200 Subject: [PATCH 1/2] Add ConnectionRefusedError and handling for it --- socketio/asyncio_server.py | 16 ++++++++++++---- socketio/exceptions.py | 14 ++++++++++++++ socketio/server.py | 15 +++++++++++---- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index cbd812b..f0e4679 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -3,6 +3,7 @@ import asyncio import engineio from . import asyncio_manager +from . import exceptions from . import packet from . import server @@ -320,11 +321,18 @@ class AsyncServer(server.Server): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) - if await self._trigger_event('connect', namespace, sid, - self.environ[sid]) is False: + + try: + success = await self._trigger_event('connect', namespace, sid, self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.get_info() + success = False + else: + fail_reason = None + + if success is False: self.manager.disconnect(sid, namespace) - await self._send_packet(sid, packet.Packet(packet.ERROR, - namespace=namespace)) + await self._send_packet(sid, packet.Packet(packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False diff --git a/socketio/exceptions.py b/socketio/exceptions.py index 5bd8697..2c325f2 100644 --- a/socketio/exceptions.py +++ b/socketio/exceptions.py @@ -4,3 +4,17 @@ class SocketIOError(Exception): class ConnectionError(SocketIOError): pass + + +class ConnectionRefusedError(ConnectionError): + """ + Raised when connection is refused on the application level + """ + def __init__(self, info): + self._info = info + + def get_info(self): + """ + This method could be overridden in subclass to add extra logic for data output + """ + return self._info diff --git a/socketio/server.py b/socketio/server.py index 449c94a..8151e53 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -4,6 +4,7 @@ import engineio import six from . import base_manager +from . import exceptions from . import packet from . import namespace @@ -485,11 +486,17 @@ class Server(object): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) - if self._trigger_event('connect', namespace, sid, - self.environ[sid]) is False: + try: + success = self._trigger_event('connect', namespace, sid, self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.get_info() + success = False + else: + fail_reason = None + + if success is False: self.manager.disconnect(sid, namespace) - self._send_packet(sid, packet.Packet(packet.ERROR, - namespace=namespace)) + self._send_packet(sid, packet.Packet(packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False From f3b5210289fd87f2780ecdcc4b89e3f9b463370b Mon Sep 17 00:00:00 2001 From: Andrey Rusanov Date: Sun, 3 Feb 2019 09:15:17 +0200 Subject: [PATCH 2/2] Add tests and docs --- docs/api.rst | 6 ++++++ docs/server.rst | 12 +++++++++++ tests/asyncio/test_asyncio_server.py | 32 +++++++++++++++++++++++++++- tests/common/test_server.py | 27 +++++++++++++++++++++++ 4 files changed, 76 insertions(+), 1 deletion(-) diff --git a/docs/api.rst b/docs/api.rst index 7139e2d..3fe850f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -114,3 +114,9 @@ API Reference .. autoclass:: AsyncRedisManager :members: + +``ConnectionRefusedError`` class +-------------------------------- + +.. autoclass:: ConnectionRefusedError + :members: diff --git a/docs/server.rst b/docs/server.rst index af856bd..22513be 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -111,6 +111,18 @@ standard WSGI format containing the request information, including HTTP headers. After inspecting the request, the connect event handler can return ``False`` to reject the connection with the client. +If any additional data has to be passed on connection reject, than instead of +returning ``False`` :class:`socketio.exceptions.ConnectionRefusedError` could +be raised: + + @sio.on('connect') + def connect(sid, environ): + message = 'Incorrect user data' + raise ConnectionRefusedError(message) + +In this case message will be returned directly to the client with rejected +connection. + Emitting Events --------------- diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index 4761050..4a8f00c 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -10,7 +10,7 @@ if six.PY3: else: import mock -from socketio import asyncio_server +from socketio import asyncio_server, exceptions from socketio import asyncio_namespace from socketio import packet from socketio import namespace @@ -276,6 +276,36 @@ class TestAsyncServer(unittest.TestCase): self.assertEqual(s.environ, {}) s.eio.send.mock.assert_any_call('123', '4/foo', binary=False) + def test_handle_connect_namespace_rejected_with_exception(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr) + handler = mock.MagicMock(side_effect=exceptions.ConnectionRefusedError('fail_reason')) + s.on('connect', handler, namespace='/foo') + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo')) + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.mock.assert_any_call('123', '4/foo,"fail_reason"', binary=False) + + def test_handle_connect_namespace_rejected_with_custom_exception(self, eio): + class CustomizedConnRefused(exceptions.ConnectionRefusedError): + def get_info(self): + return 'customized: {}'.format(self._info) + + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr) + handler = mock.MagicMock(side_effect=CustomizedConnRefused('fail_reason')) + s.on('connect', handler, namespace='/foo') + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo')) + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.mock.assert_any_call('123', '4/foo,"customized: fail_reason"', binary=False) + def test_handle_disconnect(self, eio): eio.return_value.send = AsyncMock() mgr = self._get_mock_manager() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 5e16ec5..61df07d 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -8,6 +8,7 @@ if six.PY3: else: import mock +from socketio import exceptions from socketio import packet from socketio import server from socketio import namespace @@ -218,6 +219,32 @@ class TestServer(unittest.TestCase): self.assertEqual(s.manager.disconnect.call_count, 1) s.eio.send.assert_any_call('123', '4/foo', binary=False) + def test_handle_connect_namespace_rejected_with_exception(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) + handler = mock.MagicMock(side_effect=exceptions.ConnectionRefusedError('fail_reason')) + s.on('connect', handler, namespace='/foo') + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo') + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False) + + def test_handle_connect_namespace_rejected_with_custom_exception(self, eio): + class CustomizedConnRefused(exceptions.ConnectionRefusedError): + def get_info(self): + return 'customized: {}'.format(self._info) + + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) + handler = mock.MagicMock(side_effect=CustomizedConnRefused('fail_reason')) + s.on('connect', handler, namespace='/foo') + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo') + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + s.eio.send.assert_any_call('123', '4/foo,"customized: fail_reason"', binary=False) + def test_handle_disconnect(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr)