Browse Source

Merge branch 'andreyrusanov-connection_refused_error'

pull/269/head
Miguel Grinberg 6 years ago
parent
commit
1154e43faf
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 6
      docs/api.rst
  2. 9
      docs/server.rst
  3. 15
      socketio/asyncio_server.py
  4. 16
      socketio/exceptions.py
  5. 20
      socketio/server.py
  6. 30
      tests/asyncio/test_asyncio_server.py
  7. 26
      tests/common/test_server.py

6
docs/api.rst

@ -32,6 +32,12 @@ API Reference
:members: :members:
:inherited-members: :inherited-members:
``ConnectionRefusedError`` class
--------------------------------
.. autoclass:: ConnectionRefusedError
:members:
``WSGIApp`` class ``WSGIApp`` class
----------------- -----------------

9
docs/server.rst

@ -111,6 +111,15 @@ standard WSGI format containing the request information, including HTTP
headers. After inspecting the request, the connect event handler can return headers. After inspecting the request, the connect event handler can return
``False`` to reject the connection with the client. ``False`` to reject the connection with the client.
Sometimes it is useful to pass data back to the client being rejected. In that
case instead of returning ``False``
:class:`socketio.exceptions.ConnectionRefusedError` can be raised, and all of
its argument will be sent to the client with the rejection::
@sio.on('connect')
def connect(sid, environ):
raise ConnectionRefusedError('authentication failed')
Emitting Events Emitting Events
--------------- ---------------

15
socketio/asyncio_server.py

@ -371,16 +371,23 @@ class AsyncServer(server.Server):
if self.always_connect: if self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT, await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace)) namespace=namespace))
if await self._trigger_event('connect', namespace, sid, fail_reason = None
self.environ[sid]) is False: try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet( await self._send_packet(sid, packet.Packet(
packet.DISCONNECT, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace) self.manager.disconnect(sid, namespace)
if not self.always_connect: if not self.always_connect:
await self._send_packet(sid, packet.Packet( await self._send_packet(sid, packet.Packet(
packet.ERROR, namespace=namespace)) packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover if sid in self.environ: # pragma: no cover
del self.environ[sid] del self.environ[sid]
return False return False

16
socketio/exceptions.py

@ -6,5 +6,21 @@ class ConnectionError(SocketIOError):
pass pass
class ConnectionRefusedError(ConnectionError):
"""Connection refused exception.
This exception can be raised from a connect handler when the connection
is not accepted. The positional arguments provided with the exception are
returned with the error packet to the client.
"""
def __init__(self, *args):
if len(args) == 0:
self.error_args = None
elif len(args) == 1:
self.error_args = args[0]
else:
self.error_args = args
class TimeoutError(SocketIOError): class TimeoutError(SocketIOError):
pass pass

20
socketio/server.py

@ -545,17 +545,23 @@ class Server(object):
if self.always_connect: if self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT, self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace)) namespace=namespace))
if self._trigger_event('connect', namespace, sid, fail_reason = None
self.environ[sid]) is False: try:
success = self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
self._send_packet(sid, packet.Packet(packet.DISCONNECT, self._send_packet(sid, packet.Packet(
namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace) self.manager.disconnect(sid, namespace)
if not self.always_connect: if not self.always_connect:
self._send_packet(sid, packet.Packet(packet.ERROR, self._send_packet(sid, packet.Packet(
namespace=namespace)) packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover if sid in self.environ: # pragma: no cover
del self.environ[sid] del self.environ[sid]
return False return False

30
tests/asyncio/test_asyncio_server.py

@ -10,7 +10,7 @@ if six.PY3:
else: else:
import mock import mock
from socketio import asyncio_server from socketio import asyncio_server, exceptions
from socketio import asyncio_namespace from socketio import asyncio_namespace
from socketio import exceptions from socketio import exceptions
from socketio import namespace from socketio import namespace
@ -340,6 +340,34 @@ class TestAsyncServer(unittest.TestCase):
s.eio.send.mock.assert_any_call('123', '0/foo', binary=False) s.eio.send.mock.assert_any_call('123', '0/foo', binary=False)
s.eio.send.mock.assert_any_call('123', '1/foo', binary=False) s.eio.send.mock.assert_any_call('123', '1/foo', binary=False)
def test_handle_connect_rejected_with_exception(self, eio):
eio.return_value.send = AsyncMock()
mgr = self._get_mock_manager()
s = asyncio_server.AsyncServer(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError('fail_reason'))
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
self.assertEqual(s.manager.connect.call_count, 1)
self.assertEqual(s.manager.disconnect.call_count, 1)
self.assertEqual(s.environ, {})
s.eio.send.mock.assert_any_call('123', '4"fail_reason"', binary=False)
def test_handle_connect_namespace_rejected_with_exception(self, eio):
eio.return_value.send = AsyncMock()
mgr = self._get_mock_manager()
s = asyncio_server.AsyncServer(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError('fail_reason', 1))
s.on('connect', handler, namespace='/foo')
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0/foo'))
self.assertEqual(s.manager.connect.call_count, 2)
self.assertEqual(s.manager.disconnect.call_count, 1)
self.assertEqual(s.environ, {})
s.eio.send.mock.assert_any_call('123', '4/foo,["fail_reason",1]',
binary=False)
def test_handle_disconnect(self, eio): def test_handle_disconnect(self, eio):
eio.return_value.send = AsyncMock() eio.return_value.send = AsyncMock()
mgr = self._get_mock_manager() mgr = self._get_mock_manager()

26
tests/common/test_server.py

@ -277,6 +277,32 @@ class TestServer(unittest.TestCase):
s.eio.send.assert_any_call('123', '0/foo', binary=False) s.eio.send.assert_any_call('123', '0/foo', binary=False)
s.eio.send.assert_any_call('123', '1/foo', binary=False) s.eio.send.assert_any_call('123', '1/foo', binary=False)
def test_handle_connect_rejected_with_exception(self, eio):
mgr = mock.MagicMock()
s = server.Server(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError())
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
handler.assert_called_once_with('123', 'environ')
self.assertEqual(s.manager.connect.call_count, 1)
self.assertEqual(s.manager.disconnect.call_count, 1)
self.assertEqual(s.environ, {})
s.eio.send.assert_any_call('123', '4', binary=False)
def test_handle_connect_namespace_rejected_with_exception(self, eio):
mgr = mock.MagicMock()
s = server.Server(client_manager=mgr)
handler = mock.MagicMock(
side_effect=exceptions.ConnectionRefusedError('fail_reason'))
s.on('connect', handler, namespace='/foo')
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0/foo')
self.assertEqual(s.manager.connect.call_count, 2)
self.assertEqual(s.manager.disconnect.call_count, 1)
print(s.eio.send.call_args)
s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False)
def test_handle_disconnect(self, eio): def test_handle_disconnect(self, eio):
mgr = mock.MagicMock() mgr = mock.MagicMock()
s = server.Server(client_manager=mgr) s = server.Server(client_manager=mgr)

Loading…
Cancel
Save