diff --git a/docs/server.rst b/docs/server.rst index 2b1c9ad..b027871 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -182,7 +182,7 @@ The ``connect`` and ``disconnect`` events are special; they are invoked automatically when a client connects or disconnects from the server:: @sio.event - def connect(sid, environ): + def connect(sid, environ, auth): print('connect ', sid) @sio.event @@ -193,8 +193,10 @@ The ``connect`` event is an ideal place to perform user authentication, and any necessary mapping between user entities in the application and the ``sid`` that was assigned to the client. The ``environ`` argument is a dictionary in standard WSGI format containing the request information, including HTTP -headers. After inspecting the request, the connect event handler can return -``False`` to reject the connection with the client. +headers. The ``auth`` argument contains any authentication details passed by +the client, or ``None`` if the client did not pass anything. After inspecting +the request, the connect event handler can return ``False`` to reject the +connection with the client. Sometimes it is useful to pass data back to the client being rejected. In that case instead of returning ``False`` diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 0c0b9be..778abc0 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -433,7 +433,7 @@ class AsyncServer(server.Server): else: await self.eio.send(eio_sid, encoded_packet) - async def _handle_connect(self, eio_sid, namespace): + async def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' sid = self.manager.connect(eio_sid, namespace) @@ -442,8 +442,16 @@ class AsyncServer(server.Server): packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: - success = await self._trigger_event('connect', namespace, sid, - self.environ[eio_sid]) + if data: + success = await self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], data) + else: + try: + success = await self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid]) + except TypeError: + success = await self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], None) except exceptions.ConnectionRefusedError as exc: fail_reason = exc.error_args success = False @@ -552,7 +560,7 @@ class AsyncServer(server.Server): else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - await self._handle_connect(eio_sid, pkt.namespace) + await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: await self._handle_disconnect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.EVENT: diff --git a/socketio/server.py b/socketio/server.py index 854922a..22da0ac 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -619,7 +619,7 @@ class Server(object): else: self.eio.send(eio_sid, encoded_packet) - def _handle_connect(self, eio_sid, namespace): + def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' sid = self.manager.connect(eio_sid, namespace) @@ -628,8 +628,16 @@ class Server(object): packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: - success = self._trigger_event('connect', namespace, sid, - self.environ[eio_sid]) + if data: + success = self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], data) + else: + try: + success = self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid]) + except TypeError: + success = self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], None) except exceptions.ConnectionRefusedError as exc: fail_reason = exc.error_args success = False @@ -729,7 +737,7 @@ class Server(object): else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - self._handle_connect(eio_sid, pkt.namespace) + self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: self._handle_disconnect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.EVENT: diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index d21f594..a58f319 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -377,6 +377,38 @@ class TestAsyncServer(unittest.TestCase): _run(s._handle_eio_message('456', '0')) assert s.manager.initialize.call_count == 1 + def test_handle_connect_with_auth(self, eio): + eio.return_value.send = AsyncMock() + s = asyncio_server.AsyncServer() + s.manager.initialize = mock.MagicMock() + handler = mock.MagicMock() + s.on('connect', handler) + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0{"token":"abc"}')) + assert s.manager.is_connected('1', '/') + handler.assert_called_once_with('1', 'environ', {'token': 'abc'}) + s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}') + assert s.manager.initialize.call_count == 1 + _run(s._handle_eio_connect('456', 'environ')) + _run(s._handle_eio_message('456', '0')) + assert s.manager.initialize.call_count == 1 + + def test_handle_connect_with_auth_none(self, eio): + eio.return_value.send = AsyncMock() + s = asyncio_server.AsyncServer() + s.manager.initialize = mock.MagicMock() + handler = mock.MagicMock(side_effect=[TypeError, None, None]) + s.on('connect', handler) + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0')) + assert s.manager.is_connected('1', '/') + handler.assert_called_with('1', 'environ', None) + s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}') + assert s.manager.initialize.call_count == 1 + _run(s._handle_eio_connect('456', 'environ')) + _run(s._handle_eio_message('456', '0')) + assert s.manager.initialize.call_count == 1 + def test_handle_connect_async(self, eio): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 2b0f3db..e6be2ac 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -326,6 +326,34 @@ class TestServer(unittest.TestCase): s._handle_eio_connect('456', 'environ') assert s.manager.initialize.call_count == 1 + def test_handle_connect_with_auth(self, eio): + s = server.Server() + s.manager.initialize = mock.MagicMock() + handler = mock.MagicMock() + s.on('connect', handler) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0{"token":"abc"}') + assert s.manager.is_connected('1', '/') + handler.assert_called_with('1', 'environ', {'token': 'abc'}) + s.eio.send.assert_called_once_with('123', '0{"sid":"1"}') + assert s.manager.initialize.call_count == 1 + s._handle_eio_connect('456', 'environ') + assert s.manager.initialize.call_count == 1 + + def test_handle_connect_with_auth_none(self, eio): + s = server.Server() + s.manager.initialize = mock.MagicMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('connect', handler) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + assert s.manager.is_connected('1', '/') + handler.assert_called_with('1', 'environ', None) + s.eio.send.assert_called_once_with('123', '0{"sid":"1"}') + assert s.manager.initialize.call_count == 1 + s._handle_eio_connect('456', 'environ') + assert s.manager.initialize.call_count == 1 + def test_handle_connect_namespace(self, eio): s = server.Server() handler = mock.MagicMock()