Browse Source

add ConnectionRefusedError exception

pull/269/head
Miguel Grinberg 6 years ago
parent
commit
605d1acfcd
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 6
      docs/api.rst
  2. 9
      docs/server.rst
  3. 15
      socketio/asyncio_server.py
  4. 16
      socketio/exceptions.py
  5. 20
      socketio/server.py
  6. 30
      tests/asyncio/test_asyncio_server.py
  7. 26
      tests/common/test_server.py

6
docs/api.rst

@ -32,6 +32,12 @@ API Reference
:members:
:inherited-members:
``ConnectionRefusedError`` class
--------------------------------
.. autoclass:: ConnectionRefusedError
:members:
``WSGIApp`` class
-----------------

9
docs/server.rst

@ -111,6 +111,15 @@ standard WSGI format containing the request information, including HTTP
headers. After inspecting the request, the connect event handler can return
``False`` to reject the connection with the client.
Sometimes it is useful to pass data back to the client being rejected. In that
case instead of returning ``False``
:class:`socketio.exceptions.ConnectionRefusedError` can be raised, and all of
its argument will be sent to the client with the rejection::
@sio.on('connect')
def connect(sid, environ):
raise ConnectionRefusedError('authentication failed')
Emitting Events
---------------

15
socketio/asyncio_server.py

@ -371,16 +371,23 @@ class AsyncServer(server.Server):
if self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
if await self._trigger_event('connect', namespace, sid,
self.environ[sid]) is False:
fail_reason = None
try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet(
packet.DISCONNECT, namespace=namespace))
packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace)
if not self.always_connect:
await self._send_packet(sid, packet.Packet(
packet.ERROR, namespace=namespace))
packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
return False

16
socketio/exceptions.py

@ -6,5 +6,21 @@ class ConnectionError(SocketIOError):
pass
class ConnectionRefusedError(ConnectionError):
"""Connection refused exception.
This exception can be raised from a connect handler when the connection
is not accepted. The positional arguments provided with the exception are
returned with the error packet to the client.
"""
def __init__(self, *args):
if len(args) == 0:
self.error_args = None
elif len(args) == 1:
self.error_args = args[0]
else:
self.error_args = args
class TimeoutError(SocketIOError):
pass

20
socketio/server.py

@ -545,17 +545,23 @@ class Server(object):
if self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
if self._trigger_event('connect', namespace, sid,
self.environ[sid]) is False:
fail_reason = None
try:
success = self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
self._send_packet(sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace)
if not self.always_connect:
self._send_packet(sid, packet.Packet(packet.ERROR,
namespace=namespace))
self._send_packet(sid, packet.Packet(
packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
return False

30
tests/asyncio/test_asyncio_server.py

@ -10,7 +10,7 @@ if six.PY3:
else:
import mock
from socketio import asyncio_server
from socketio import asyncio_server, exceptions
from socketio import asyncio_namespace
from socketio import exceptions
from socketio import namespace
@ -340,6 +340,34 @@ class TestAsyncServer(unittest.TestCase):
s.eio.send.mock.assert_any_call('123', '0/foo', binary=False)
s.eio.send.mock.assert_any_call('123', '1/foo', binary=False)
def test_handle_connect_rejected_with_exception(self, eio):
eio.return_value.send = AsyncMock()
mgr = self._get_mock_manager()
s = asyncio_server.AsyncServer(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError('fail_reason'))
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
self.assertEqual(s.manager.connect.call_count, 1)
self.assertEqual(s.manager.disconnect.call_count, 1)
self.assertEqual(s.environ, {})
s.eio.send.mock.assert_any_call('123', '4"fail_reason"', binary=False)
def test_handle_connect_namespace_rejected_with_exception(self, eio):
eio.return_value.send = AsyncMock()
mgr = self._get_mock_manager()
s = asyncio_server.AsyncServer(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError('fail_reason', 1))
s.on('connect', handler, namespace='/foo')
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0/foo'))
self.assertEqual(s.manager.connect.call_count, 2)
self.assertEqual(s.manager.disconnect.call_count, 1)
self.assertEqual(s.environ, {})
s.eio.send.mock.assert_any_call('123', '4/foo,["fail_reason",1]',
binary=False)
def test_handle_disconnect(self, eio):
eio.return_value.send = AsyncMock()
mgr = self._get_mock_manager()

26
tests/common/test_server.py

@ -277,6 +277,32 @@ class TestServer(unittest.TestCase):
s.eio.send.assert_any_call('123', '0/foo', binary=False)
s.eio.send.assert_any_call('123', '1/foo', binary=False)
def test_handle_connect_rejected_with_exception(self, eio):
mgr = mock.MagicMock()
s = server.Server(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError())
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
handler.assert_called_once_with('123', 'environ')
self.assertEqual(s.manager.connect.call_count, 1)
self.assertEqual(s.manager.disconnect.call_count, 1)
self.assertEqual(s.environ, {})
s.eio.send.assert_any_call('123', '4', binary=False)
def test_handle_connect_namespace_rejected_with_exception(self, eio):
mgr = mock.MagicMock()
s = server.Server(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError('fail_reason'))
s.on('connect', handler, namespace='/foo')
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0/foo')
self.assertEqual(s.manager.connect.call_count, 2)
self.assertEqual(s.manager.disconnect.call_count, 1)
print(s.eio.send.call_args)
s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False)
def test_handle_disconnect(self, eio):
mgr = mock.MagicMock()
s = server.Server(client_manager=mgr)

Loading…
Cancel
Save