From f752312884332e86946df0c4c60ee52dd0590164 Mon Sep 17 00:00:00 2001
From: Miguel Grinberg <miguel.grinberg@gmail.com>
Date: Wed, 27 Feb 2019 19:53:14 +0000
Subject: [PATCH] Avoid double calls to client disconnect handlers

Fixes #261
---
 socketio/asyncio_client.py           |  4 ++--
 socketio/client.py                   |  6 +++---
 tests/asyncio/test_asyncio_client.py | 15 ++++++++-------
 tests/common/test_client.py          | 15 +++++++++------
 4 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py
index 88a6e27..8489e35 100644
--- a/socketio/asyncio_client.py
+++ b/socketio/asyncio_client.py
@@ -235,11 +235,11 @@ class AsyncClient(client.Client):
 
         Note: this method is a coroutine.
         """
+        # here we just request the disconnection
+        # later in _handle_eio_disconnect we invoke the disconnect handler
         for n in self.namespaces:
-            await self._trigger_event('disconnect', namespace=n)
             await self._send_packet(packet.Packet(packet.DISCONNECT,
                                     namespace=n))
-        await self._trigger_event('disconnect', namespace='/')
         await self._send_packet(packet.Packet(
             packet.DISCONNECT, namespace='/'))
         await self.eio.disconnect(abort=True)
diff --git a/socketio/client.py b/socketio/client.py
index 84e13ef..2e2f242 100644
--- a/socketio/client.py
+++ b/socketio/client.py
@@ -317,7 +317,7 @@ class Client(object):
         def event_callback(*args):
             callback_args.append(args)
             callback_event.set()
-        
+
         self.emit(event, data=data, namespace=namespace,
                   callback=event_callback)
         if not callback_event.wait(timeout=timeout):
@@ -328,10 +328,10 @@ class Client(object):
 
     def disconnect(self):
         """Disconnect from the server."""
+        # here we just request the disconnection
+        # later in _handle_eio_disconnect we invoke the disconnect handler
         for n in self.namespaces:
-            self._trigger_event('disconnect', namespace=n)
             self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
-        self._trigger_event('disconnect', namespace='/')
         self._send_packet(packet.Packet(
             packet.DISCONNECT, namespace='/'))
         self.eio.disconnect(abort=True)
diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py
index 3114443..79556f5 100644
--- a/tests/asyncio/test_asyncio_client.py
+++ b/tests/asyncio/test_asyncio_client.py
@@ -291,9 +291,9 @@ class TestAsyncClient(unittest.TestCase):
         c._send_packet = AsyncMock()
         c.eio = mock.MagicMock()
         c.eio.disconnect = AsyncMock()
+        c.eio.state = 'connected'
         _run(c.disconnect())
-        c._trigger_event.mock.assert_called_once_with(
-            'disconnect', namespace='/')
+        self.assertEqual(c._trigger_event.mock.call_count, 0)
         self.assertEqual(c._send_packet.mock.call_count, 1)
         expected_packet = packet.Packet(packet.DISCONNECT, namespace='/')
         self.assertEqual(c._send_packet.mock.call_args_list[0][0][0].encode(),
@@ -305,12 +305,11 @@ class TestAsyncClient(unittest.TestCase):
         c.namespaces = ['/foo', '/bar']
         c._trigger_event = AsyncMock()
         c._send_packet = AsyncMock()
+        c.eio = mock.MagicMock()
+        c.eio.disconnect = AsyncMock()
+        c.eio.state = 'connected'
         _run(c.disconnect())
-        self.assertEqual(c._trigger_event.mock.call_args_list, [
-            mock.call('disconnect', namespace='/foo'),
-            mock.call('disconnect', namespace='/bar'),
-            mock.call('disconnect', namespace='/')
-        ])
+        self.assertEqual(c._trigger_event.mock.call_count, 0)
         self.assertEqual(c._send_packet.mock.call_count, 3)
         expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo')
         self.assertEqual(c._send_packet.mock.call_args_list[0][0][0].encode(),
@@ -632,6 +631,7 @@ class TestAsyncClient(unittest.TestCase):
     def test_eio_disconnect(self):
         c = asyncio_client.AsyncClient()
         c._trigger_event = AsyncMock()
+        c.eio.state = 'connected'
         _run(c._handle_eio_disconnect())
         c._trigger_event.mock.assert_called_once_with(
             'disconnect', namespace='/')
@@ -640,6 +640,7 @@ class TestAsyncClient(unittest.TestCase):
         c = asyncio_client.AsyncClient()
         c.namespaces = ['/foo', '/bar']
         c._trigger_event = AsyncMock()
+        c.eio.state = 'connected'
         _run(c._handle_eio_disconnect())
         c._trigger_event.mock.assert_any_call('disconnect', namespace='/foo')
         c._trigger_event.mock.assert_any_call('disconnect', namespace='/bar')
diff --git a/tests/common/test_client.py b/tests/common/test_client.py
index 8fd078d..b23b7d7 100644
--- a/tests/common/test_client.py
+++ b/tests/common/test_client.py
@@ -376,8 +376,9 @@ class TestClient(unittest.TestCase):
         c._trigger_event = mock.MagicMock()
         c._send_packet = mock.MagicMock()
         c.eio = mock.MagicMock()
+        c.eio.state = 'connected'
         c.disconnect()
-        c._trigger_event.assert_called_once_with('disconnect', namespace='/')
+        self.assertEqual(c._trigger_event.call_count, 0)
         self.assertEqual(c._send_packet.call_count, 1)
         expected_packet = packet.Packet(packet.DISCONNECT, namespace='/')
         self.assertEqual(c._send_packet.call_args_list[0][0][0].encode(),
@@ -389,12 +390,10 @@ class TestClient(unittest.TestCase):
         c.namespaces = ['/foo', '/bar']
         c._trigger_event = mock.MagicMock()
         c._send_packet = mock.MagicMock()
+        c.eio = mock.MagicMock()
+        c.eio.state = 'connected'
         c.disconnect()
-        self.assertEqual(c._trigger_event.call_args_list, [
-            mock.call('disconnect', namespace='/foo'),
-            mock.call('disconnect', namespace='/bar'),
-            mock.call('disconnect', namespace='/')
-        ])
+        self.assertEqual(c._trigger_event.call_count, 0)
         self.assertEqual(c._send_packet.call_count, 3)
         expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo')
         self.assertEqual(c._send_packet.call_args_list[0][0][0].encode(),
@@ -731,6 +730,8 @@ class TestClient(unittest.TestCase):
     def test_eio_disconnect(self):
         c = client.Client()
         c._trigger_event = mock.MagicMock()
+        c.start_background_task = mock.MagicMock()
+        c.eio.state = 'connected'
         c._handle_eio_disconnect()
         c._trigger_event.assert_called_once_with('disconnect', namespace='/')
 
@@ -738,6 +739,8 @@ class TestClient(unittest.TestCase):
         c = client.Client()
         c.namespaces = ['/foo', '/bar']
         c._trigger_event = mock.MagicMock()
+        c.start_background_task = mock.MagicMock()
+        c.eio.state = 'connected'
         c._handle_eio_disconnect()
         c._trigger_event.assert_any_call('disconnect', namespace='/foo')
         c._trigger_event.assert_any_call('disconnect', namespace='/bar')