From 0db35c87e98ed311c8e6cf4c8142bfda9b86949c Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sat, 9 Mar 2019 18:37:40 +0000 Subject: [PATCH] send connect packet before invoking connect handler --- socketio/asyncio_server.py | 14 +++++++------- socketio/server.py | 12 ++++++------ tests/asyncio/test_asyncio_server.py | 6 ++++-- tests/common/test_server.py | 6 ++++-- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 144b79c..99919ab 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -191,7 +191,7 @@ class AsyncServer(server.Server): def event_callback(*args): callback_args.append(args) callback_event.set() - + await self.emit(event, data=data, room=sid, namespace=namespace, callback=event_callback, **kwargs) try: @@ -201,7 +201,7 @@ class AsyncServer(server.Server): return callback_args[0] if len(callback_args[0]) > 1 \ else callback_args[0][0] if len(callback_args[0]) == 1 \ else None - + async def close_room(self, room, namespace=None): """Close a room. @@ -368,17 +368,17 @@ class AsyncServer(server.Server): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) + await self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) if await self._trigger_event('connect', namespace, sid, self.environ[sid]) is False: - self.manager.disconnect(sid, namespace) - await self._send_packet(sid, packet.Packet(packet.ERROR, + self.manager.pre_disconnect(sid, namespace) + await self._send_packet(sid, packet.Packet(packet.DISCONNECT, namespace=namespace)) + self.manager.disconnect(sid, namespace) if sid in self.environ: # pragma: no cover del self.environ[sid] return False - else: - await self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) async def _handle_disconnect(self, sid, namespace): """Handle a client disconnect.""" diff --git a/socketio/server.py b/socketio/server.py index 8157411..351d849 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -300,7 +300,7 @@ class Server(object): def event_callback(*args): callback_args.append(args) callback_event.set() - + self.emit(event, data=data, room=sid, namespace=namespace, callback=event_callback, **kwargs) if not callback_event.wait(timeout=timeout): @@ -530,17 +530,17 @@ class Server(object): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) + self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) if self._trigger_event('connect', namespace, sid, self.environ[sid]) is False: - self.manager.disconnect(sid, namespace) - self._send_packet(sid, packet.Packet(packet.ERROR, + self.manager.pre_disconnect(sid, namespace) + self._send_packet(sid, packet.Packet(packet.DISCONNECT, namespace=namespace)) + self.manager.disconnect(sid, namespace) if sid in self.environ: # pragma: no cover del self.environ[sid] return False - else: - self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) def _handle_disconnect(self, sid, namespace): """Handle a client disconnect.""" diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index fc855c6..1df7f89 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -295,7 +295,8 @@ class TestAsyncServer(unittest.TestCase): self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.mock.assert_called_once_with('123', '4', binary=False) + s.eio.send.mock.assert_any_call('123', '0', binary=False) + s.eio.send.mock.assert_any_call('123', '1', binary=False) def test_handle_connect_namespace_rejected(self, eio): eio.return_value.send = AsyncMock() @@ -308,7 +309,8 @@ class TestAsyncServer(unittest.TestCase): self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.mock.assert_any_call('123', '4/foo', binary=False) + s.eio.send.mock.assert_any_call('123', '0/foo', binary=False) + s.eio.send.mock.assert_any_call('123', '1/foo', binary=False) def test_handle_disconnect(self, eio): eio.return_value.send = AsyncMock() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 6a54f6f..0b278aa 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -239,7 +239,8 @@ class TestServer(unittest.TestCase): self.assertEqual(s.manager.connect.call_count, 1) self.assertEqual(s.manager.disconnect.call_count, 1) self.assertEqual(s.environ, {}) - s.eio.send.assert_called_once_with('123', '4', binary=False) + s.eio.send.assert_any_call('123', '0', binary=False) + s.eio.send.assert_any_call('123', '1', binary=False) def test_handle_connect_namespace_rejected(self, eio): mgr = mock.MagicMock() @@ -250,7 +251,8 @@ class TestServer(unittest.TestCase): s._handle_eio_message('123', '0/foo') self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) - s.eio.send.assert_any_call('123', '4/foo', binary=False) + s.eio.send.assert_any_call('123', '0/foo', binary=False) + s.eio.send.assert_any_call('123', '1/foo', binary=False) def test_handle_disconnect(self, eio): mgr = mock.MagicMock()