Browse Source

handle keyboard interrupt during reconnect (Fixes #301)

pull/319/head
Miguel Grinberg 6 years ago
parent
commit
fa53e3869c
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 11
      socketio/asyncio_client.py
  2. 25
      socketio/client.py
  3. 73
      tests/asyncio/test_asyncio_client.py
  4. 32
      tests/common/test_client.py

11
socketio/asyncio_client.py

@ -355,6 +355,8 @@ class AsyncClient(client.Client):
event, *args)
async def _handle_reconnect(self):
self._reconnect_abort.clear()
client.reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
@ -366,7 +368,12 @@ class AsyncClient(client.Client):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
await self.sleep(delay)
try:
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
self.logger.info('Reconnect task aborted')
break
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
attempt_count += 1
try:
await self.connect(self.connection_url,
@ -385,6 +392,7 @@ class AsyncClient(client.Client):
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
client.reconnecting_clients.remove(self)
def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
@ -422,6 +430,7 @@ class AsyncClient(client.Client):
async def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
self._reconnect_abort.set()
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')

25
socketio/client.py

@ -1,6 +1,7 @@
import itertools
import logging
import random
import signal
import engineio
import six
@ -10,6 +11,21 @@ from . import namespace
from . import packet
default_logger = logging.getLogger('socketio.client')
reconnecting_clients = []
def signal_handler(sig, frame): # pragma: no cover
"""SIGINT handler.
Notify any clients that are in a reconnect loop to abort. Other
disconnection tasks are handled at the engine.io level.
"""
for client in reconnecting_clients[:]:
client._reconnect_abort.set()
return original_signal_handler(sig, frame)
original_signal_handler = signal.signal(signal.SIGINT, signal_handler)
class Client(object):
@ -102,6 +118,7 @@ class Client(object):
self.callbacks = {}
self._binary_packet = None
self._reconnect_task = None
self._reconnect_abort = self.eio.create_event()
def is_asyncio_based(self):
return False
@ -486,6 +503,8 @@ class Client(object):
event, *args)
def _handle_reconnect(self):
self._reconnect_abort.clear()
reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
@ -497,7 +516,10 @@ class Client(object):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
self.sleep(delay)
print('***', self._reconnect_abort.wait)
if self._reconnect_abort.wait(delay):
self.logger.info('Reconnect task aborted')
break
attempt_count += 1
try:
self.connect(self.connection_url,
@ -516,6 +538,7 @@ class Client(object):
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
reconnecting_clients.remove(self)
def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""

73
tests/asyncio/test_asyncio_client.py

@ -1,4 +1,5 @@
import asyncio
from contextlib import contextmanager
import sys
import unittest
@ -26,6 +27,19 @@ def AsyncMock(*args, **kwargs):
return mock_coro
@contextmanager
def mock_wait_for():
async def fake_wait_for(coro, timeout):
await coro
await fake_wait_for._mock(timeout)
original_wait_for = asyncio.wait_for
asyncio.wait_for = fake_wait_for
fake_wait_for._mock = AsyncMock()
yield
asyncio.wait_for = original_wait_for
def _run(coro):
"""Run the given coroutine."""
return asyncio.get_event_loop().run_until_complete(coro)
@ -542,51 +556,64 @@ class TestAsyncClient(unittest.TestCase):
_run(c._trigger_event('foo', '/', 1, '2'))
self.assertEqual(result, [1, '2'])
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect(self, random):
def test_handle_reconnect(self, random, wait_for):
c = asyncio_client.AsyncClient()
c._reconnect_task = 'foo'
c.sleep = AsyncMock()
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(c.sleep.mock.call_count, 3)
self.assertEqual(c.sleep.mock.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(4.0)
])
self.assertEqual(wait_for.mock.call_count, 3)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5, 4.0])
self.assertEqual(c._reconnect_task, None)
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_max_delay(self, random):
def test_handle_reconnect_max_delay(self, random, wait_for):
c = asyncio_client.AsyncClient(reconnection_delay_max=3)
c._reconnect_task = 'foo'
c.sleep = AsyncMock()
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(c.sleep.mock.call_count, 3)
self.assertEqual(c.sleep.mock.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(3.0)
])
self.assertEqual(wait_for.mock.call_count, 3)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5, 3.0])
self.assertEqual(c._reconnect_task, None)
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_max_attempts(self, random):
def test_handle_reconnect_max_attempts(self, random, wait_for):
c = asyncio_client.AsyncClient(reconnection_attempts=2)
c._reconnect_task = 'foo'
c.sleep = AsyncMock()
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(c.sleep.mock.call_count, 2)
self.assertEqual(c.sleep.mock.call_args_list, [
mock.call(1.5),
mock.call(1.5)
])
self.assertEqual(wait_for.mock.call_count, 2)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5])
self.assertEqual(c._reconnect_task, 'foo')
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=[asyncio.TimeoutError, None])
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_aborted(self, random, wait_for):
c = asyncio_client.AsyncClient()
c._reconnect_task = 'foo'
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(wait_for.mock.call_count, 2)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5])
self.assertEqual(c._reconnect_task, 'foo')
def test_eio_connect(self):

32
tests/common/test_client.py

@ -671,12 +671,12 @@ class TestClient(unittest.TestCase):
def test_handle_reconnect(self, random):
c = client.Client()
c._reconnect_task = 'foo'
c.sleep = mock.MagicMock()
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
c.connect = mock.MagicMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
c._handle_reconnect()
self.assertEqual(c.sleep.call_count, 3)
self.assertEqual(c.sleep.call_args_list, [
self.assertEqual(c._reconnect_abort.wait.call_count, 3)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(4.0)
@ -687,12 +687,12 @@ class TestClient(unittest.TestCase):
def test_handle_reconnect_max_delay(self, random):
c = client.Client(reconnection_delay_max=3)
c._reconnect_task = 'foo'
c.sleep = mock.MagicMock()
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
c.connect = mock.MagicMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
c._handle_reconnect()
self.assertEqual(c.sleep.call_count, 3)
self.assertEqual(c.sleep.call_args_list, [
self.assertEqual(c._reconnect_abort.wait.call_count, 3)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(3.0)
@ -703,12 +703,26 @@ class TestClient(unittest.TestCase):
def test_handle_reconnect_max_attempts(self, random):
c = client.Client(reconnection_attempts=2)
c._reconnect_task = 'foo'
c.sleep = mock.MagicMock()
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
c.connect = mock.MagicMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
c._handle_reconnect()
self.assertEqual(c.sleep.call_count, 2)
self.assertEqual(c.sleep.call_args_list, [
self.assertEqual(c._reconnect_abort.wait.call_count, 2)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5)
])
self.assertEqual(c._reconnect_task, 'foo')
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_aborted(self, random):
c = client.Client()
c._reconnect_task = 'foo'
c._reconnect_abort.wait = mock.MagicMock(side_effect=[False, True])
c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError)
c._handle_reconnect()
self.assertEqual(c._reconnect_abort.wait.call_count, 2)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5)
])

Loading…
Cancel
Save