From 63ddde0fb01afdd9f9708eb7852d6d322d5a46d4 Mon Sep 17 00:00:00 2001
From: Miguel Grinberg <miguel.grinberg@gmail.com>
Date: Sun, 16 Dec 2018 14:30:00 +0000
Subject: [PATCH] client reconnection support

---
 socketio/client.py     | 96 +++++++++++++++++++++++++++++++++++-------
 socketio/exceptions.py |  6 +++
 2 files changed, 86 insertions(+), 16 deletions(-)
 create mode 100644 socketio/exceptions.py

diff --git a/socketio/client.py b/socketio/client.py
index de350b5..0d21e90 100644
--- a/socketio/client.py
+++ b/socketio/client.py
@@ -1,9 +1,11 @@
 import itertools
 import logging
+import random
 
 import engineio
 import six
 
+from . import exceptions
 from . import namespace
 from . import packet
 
@@ -55,6 +57,13 @@ class Client(object):
                  reconnection_delay=1, reconnection_delay_max=5,
                  randomization_factor=0.5, logger=False, binary=False,
                  json=None, **kwargs):
+        self.reconnection = reconnection
+        self.reconnection_attempts = reconnection_attempts
+        self.reconnection_delay = reconnection_delay
+        self.reconnection_delay_max = reconnection_delay_max
+        self.randomization_factor = randomization_factor
+        self.binary = binary
+
         engineio_options = kwargs
         engineio_logger = engineio_options.pop('engineio_logger', None)
         if engineio_logger is not None:
@@ -67,12 +76,6 @@ class Client(object):
         self.eio.on('connect', self._handle_eio_connect)
         self.eio.on('message', self._handle_eio_message)
         self.eio.on('disconnect', self._handle_eio_disconnect)
-        self.binary = binary
-        self.namespaces = None
-        self.handlers = {}
-        self.namespace_handlers = {}
-        self.callbacks = {}
-        self._binary_packet = None
 
         if not isinstance(logger, bool):
             self.logger = logger
@@ -86,6 +89,19 @@ class Client(object):
                     self.logger.setLevel(logging.ERROR)
                 self.logger.addHandler(logging.StreamHandler())
 
+        self.connection_url = None
+        self.connection_headers = None
+        self.connection_transports = None
+        self.connection_namespaces = None
+        self.socketio_path = None
+
+        self.namespaces = None
+        self.handlers = {}
+        self.namespace_handlers = {}
+        self.callbacks = {}
+        self._binary_packet = None
+        self._reconnect_task = None
+
     def is_asyncio_based(self):
         return False
 
@@ -174,12 +190,21 @@ class Client(object):
             sio = socketio.Client()
             sio.connect('http://localhost:5000')
         """
+        self.connection_url = url
+        self.connection_headers = headers
+        self.connection_transports = transports
+        self.connection_namespaces = namespaces
+        self.socketio_path = socketio_path
+
         if namespaces is None:
             namespaces = set(self.handlers.keys()).union(
                 set(self.namespace_handlers.keys()))
         self.namespaces = [n for n in namespaces if n != '/']
-        self.eio.connect(url, headers=headers, transports=transports,
-                         engineio_path=socketio_path)
+        try:
+            self.eio.connect(url, headers=headers, transports=transports,
+                             engineio_path=socketio_path)
+        except engineio.exceptions.ConnectionError as exc:
+            six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
 
     def wait(self):
         """Wait until the connection with the server ends.
@@ -187,7 +212,13 @@ class Client(object):
         Client applications can use this function to block the main thread
         during the life of the connection.
         """
-        self.eio.wait()
+        while True:
+            self.eio.wait()
+            if not self._reconnect_task:
+                break
+            self._reconnect_task.join()
+            if self.eio.state != 'connected':
+                break
 
     def emit(self, event, data=None, namespace=None, callback=None):
         """Emit a custom event to one or more connected clients.
@@ -208,7 +239,7 @@ class Client(object):
                          when addressing an individual client.
         """
         namespace = namespace or '/'
-        self.logger.info('emitting event "%s" [%s]', event, namespace)
+        self.logger.info('Emitting event "%s" [%s]', event, namespace)
         if callback is not None:
             id = self._generate_ack_id(namespace, callback)
         else:
@@ -317,7 +348,7 @@ class Client(object):
 
     def _handle_connect(self, namespace):
         namespace = namespace or '/'
-        self.logger.info('namespace {} is connected'.format(namespace))
+        self.logger.info('Namespace {} is connected'.format(namespace))
         self._trigger_event('connect', namespace=namespace)
         if namespace == '/':
             for n in self.namespaces:
@@ -331,7 +362,7 @@ class Client(object):
 
     def _handle_event(self, namespace, id, data):
         namespace = namespace or '/'
-        self.logger.info('received event "%s" [%s]', data[0], namespace)
+        self.logger.info('Received event "%s" [%s]', data[0], namespace)
         self._handle_event_internal(data, namespace, id)
 
     def _handle_event_internal(self, data, namespace, id):
@@ -354,7 +385,7 @@ class Client(object):
 
     def _handle_ack(self, namespace, id, data):
         namespace = namespace or '/'
-        self.logger.info('received ack [%s]', namespace)
+        self.logger.info('Received ack [%s]', namespace)
         callback = None
         try:
             callback = self.callbacks[namespace][id]
@@ -368,7 +399,7 @@ class Client(object):
 
     def _handle_error(self, namespace, data):
         namespace = namespace or '/'
-        self.logger.info('connection to namespace {} was rejected'.format(
+        self.logger.info('Connection to namespace {} was rejected'.format(
             namespace))
         if namespace in self.namespaces:
             self.namespaces.remove(namespace)
@@ -384,9 +415,39 @@ class Client(object):
             return self.namespace_handlers[namespace].trigger_event(
                 event, *args)
 
+    def _handle_reconnect(self):
+        attempt_count = 0
+        current_delay = self.reconnection_delay
+        while True:
+            delay = current_delay
+            current_delay *= 2
+            if delay > self.reconnection_delay_max:
+                delay = self.reconnection_delay_max
+            delay += self.randomization_factor * (2 * random.random() - 1)
+            self.logger.info(
+                'Connection failed, new attempt in {:.02f} seconds'.format(
+                    delay))
+            self.sleep(delay)
+            attempt_count += 1
+            try:
+                self.connect(self.connection_url,
+                             headers=self.connection_headers,
+                             transports=self.connection_transports,
+                             socketio_path=self.socketio_path)
+            except (exceptions.ConnectionError, ValueError):
+                pass
+            else:
+                self.logger.info('Reconnection successful')
+                break
+            if self.reconnection_attempts and \
+                    attempt_count >= self.reconnection_attempts:
+                self.logger.info(
+                    'Maximum reconnection attempts reached, giving up')
+                break
+
     def _handle_eio_connect(self):
         """Handle the Engine.IO connection event."""
-        self.logger.info('engine.io connection established')
+        self.logger.info('Engine.IO connection established')
 
     def _handle_eio_message(self, data):
         """Dispatch Engine.IO messages."""
@@ -418,12 +479,15 @@ class Client(object):
 
     def _handle_eio_disconnect(self):
         """Handle the Engine.IO disconnection event."""
-        self.logger.info('engine.io connection dropped')
+        self.logger.info('Engine.IO connection dropped')
         for n in self.namespaces:
             self._trigger_event('disconnect', namespace=n)
         self._trigger_event('disconnect', namespace='/')
         self.callbacks = {}
         self._binary_packet = None
+        if self.eio.state == 'connected' and self.reconnection:
+            self._reconnect_task = self.start_background_task(
+                self._handle_reconnect)
 
     def _engineio_client_class(self):
         return engineio.Client
diff --git a/socketio/exceptions.py b/socketio/exceptions.py
new file mode 100644
index 0000000..5bd8697
--- /dev/null
+++ b/socketio/exceptions.py
@@ -0,0 +1,6 @@
+class SocketIOError(Exception):
+    pass
+
+
+class ConnectionError(SocketIOError):
+    pass