diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 99919ab..d50b3db 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -368,17 +368,25 @@ class AsyncServer(server.Server): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) - await self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + if self.always_connect: + await self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) if await self._trigger_event('connect', namespace, sid, self.environ[sid]) is False: - self.manager.pre_disconnect(sid, namespace) - await self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + if self.always_connect: + self.manager.pre_disconnect(sid, namespace) + await self._send_packet(sid, packet.Packet( + packet.DISCONNECT, namespace=namespace)) self.manager.disconnect(sid, namespace) + if not self.always_connect: + await self._send_packet(sid, packet.Packet( + packet.ERROR, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False + elif not self.always_connect: + await self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) async def _handle_disconnect(self, sid, namespace): """Handle a client disconnect.""" diff --git a/socketio/server.py b/socketio/server.py index 351d849..7002039 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -38,6 +38,16 @@ class Server(object): executed in separate threads. To run handlers for a client synchronously, set to ``False``. The default is ``True``. + :param always_connect: When set to ``False``, new connections are + provisory until the connect handler returns + something other than ``False``, at which point they + are accepted. When set to ``True``, connections are + immediately accepted, and then if the connect + handler returns ``False`` a disconnect is issued. + Set to ``True`` if you need to emit events from the + connect handler and your client is confused when it + receives events before the connection acceptance. + In any other case use the default of ``False``. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -79,7 +89,8 @@ class Server(object): ``False``. The default is ``False``. """ def __init__(self, client_manager=None, logger=False, binary=False, - json=None, async_handlers=True, **kwargs): + json=None, async_handlers=True, always_connect=False, + **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -119,6 +130,7 @@ class Server(object): self.manager_initialized = False self.async_handlers = async_handlers + self.always_connect = always_connect self.async_mode = self.eio.async_mode @@ -530,17 +542,26 @@ class Server(object): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) - self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + if self.always_connect: + self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) if self._trigger_event('connect', namespace, sid, self.environ[sid]) is False: - self.manager.pre_disconnect(sid, namespace) - self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + if self.always_connect: + self.manager.pre_disconnect(sid, namespace) + self._send_packet(sid, packet.Packet(packet.DISCONNECT, + namespace=namespace)) self.manager.disconnect(sid, namespace) + if not self.always_connect: + self._send_packet(sid, packet.Packet(packet.ERROR, + namespace=namespace)) + if sid in self.environ: # pragma: no cover del self.environ[sid] return False + elif not self.always_connect: + self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) def _handle_disconnect(self, sid, namespace): """Handle a client disconnect.""" diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index 1df7f89..6c84154 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -295,8 +295,7 @@ class TestAsyncServer(unittest.TestCase): self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.mock.assert_any_call('123', '0', binary=False) - s.eio.send.mock.assert_any_call('123', '1', binary=False) + s.eio.send.mock.assert_called_once_with('123', '4', binary=False) def test_handle_connect_namespace_rejected(self, eio): eio.return_value.send = AsyncMock() @@ -309,6 +308,35 @@ class TestAsyncServer(unittest.TestCase): self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) + s.eio.send.mock.assert_any_call('123', '4/foo', binary=False) + + def test_handle_connect_rejected_always_connect(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr, + always_connect=True) + handler = mock.MagicMock(return_value=False) + s.on('connect', handler) + _run(s._handle_eio_connect('123', 'environ')) + handler.assert_called_once_with('123', 'environ') + self.assertEqual(s.manager.connect.call_count, 1) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.mock.assert_any_call('123', '0', binary=False) + s.eio.send.mock.assert_any_call('123', '1', binary=False) + + def test_handle_connect_namespace_rejected_always_connect(self, eio): + eio.return_value.send = AsyncMock() + mgr = self._get_mock_manager() + s = asyncio_server.AsyncServer(client_manager=mgr, + always_connect=True) + handler = mock.MagicMock(return_value=False) + s.on('connect', handler, namespace='/foo') + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo')) + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) s.eio.send.mock.assert_any_call('123', '0/foo', binary=False) s.eio.send.mock.assert_any_call('123', '1/foo', binary=False) diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 0b278aa..1f8a70c 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -239,8 +239,7 @@ class TestServer(unittest.TestCase): self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.assert_any_call('123', '0', binary=False) - s.eio.send.assert_any_call('123', '1', binary=False) + s.eio.send.assert_called_once_with('123', '4', binary=False) def test_handle_connect_namespace_rejected(self, eio): mgr = mock.MagicMock() @@ -251,6 +250,30 @@ class TestServer(unittest.TestCase): s._handle_eio_message('123', '0/foo') self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) + s.eio.send.assert_any_call('123', '4/foo', binary=False) + + def test_handle_connect_rejected_always_connect(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr, always_connect=True) + handler = mock.MagicMock(return_value=False) + s.on('connect', handler) + s._handle_eio_connect('123', 'environ') + handler.assert_called_once_with('123', 'environ') + self.assertEqual(s.manager.connect.call_count, 1) + self.assertEqual(s.manager.disconnect.call_count, 1) + self.assertEqual(s.environ, {}) + s.eio.send.assert_any_call('123', '0', binary=False) + s.eio.send.assert_any_call('123', '1', binary=False) + + def test_handle_connect_namespace_rejected_always_connect(self, eio): + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr, always_connect=True) + handler = mock.MagicMock(return_value=False) + s.on('connect', handler, namespace='/foo') + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo') + self.assertEqual(s.manager.connect.call_count, 2) + self.assertEqual(s.manager.disconnect.call_count, 1) s.eio.send.assert_any_call('123', '0/foo', binary=False) s.eio.send.assert_any_call('123', '1/foo', binary=False)