diff --git a/socketio/asyncio_manager.py b/socketio/asyncio_manager.py index 08b02d7..01bda69 100644 --- a/socketio/asyncio_manager.py +++ b/socketio/asyncio_manager.py @@ -44,7 +44,7 @@ class AsyncManager(BaseManager): callback = self.callbacks[sid][namespace][id] except KeyError: # if we get an unknown callback we just ignore it - self.server.logger.warning('Unknown callback received, ignoring.') + self._get_logger().warning('Unknown callback received, ignoring.') else: del self.callbacks[sid][namespace][id] if callback is not None: diff --git a/socketio/asyncio_pubsub_manager.py b/socketio/asyncio_pubsub_manager.py index 578e734..6fdba6d 100644 --- a/socketio/asyncio_pubsub_manager.py +++ b/socketio/asyncio_pubsub_manager.py @@ -24,17 +24,18 @@ class AsyncPubSubManager(AsyncManager): """ name = 'asyncpubsub' - def __init__(self, channel='socketio', write_only=False): + def __init__(self, channel='socketio', write_only=False, logger=None): super().__init__() self.channel = channel self.write_only = write_only self.host_id = uuid.uuid4().hex + self.logger = logger def initialize(self): super().initialize() if not self.write_only: self.thread = self.server.start_background_task(self._thread) - self.server.logger.info(self.name + ' backend initialized.') + self._get_logger().info(self.name + ' backend initialized.') async def emit(self, event, data, namespace=None, room=None, skip_sid=None, callback=None, **kwargs): diff --git a/socketio/asyncio_redis_manager.py b/socketio/asyncio_redis_manager.py index 7cb2a53..2265937 100644 --- a/socketio/asyncio_redis_manager.py +++ b/socketio/asyncio_redis_manager.py @@ -1,5 +1,4 @@ import asyncio -import logging import pickle from urllib.parse import urlparse @@ -10,8 +9,6 @@ except ImportError: from .asyncio_pubsub_manager import AsyncPubSubManager -logger = logging.getLogger('socketio') - def _parse_redis_url(url): p = urlparse(url) @@ -52,7 +49,7 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover name = 'aioredis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False): + write_only=False, logger=None): if aioredis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install aioredis" in your ' @@ -60,7 +57,7 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover self.host, self.port, self.password, self.db = _parse_redis_url(url) self.pub = None self.sub = None - super().__init__(channel=channel, write_only=write_only) + super().__init__(channel=channel, write_only=write_only, logger=logger) async def _publish(self, data): retry = True @@ -74,11 +71,13 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover pickle.dumps(data)) except (aioredis.RedisError, OSError): if retry: - logger.error('Cannot publish to redis... retrying') + self._get_logger().error('Cannot publish to redis... ' + 'retrying') self.pub = None retry = False else: - logger.error('Cannot publish to redis... giving up') + self._get_logger().error('Cannot publish to redis... ' + 'giving up') break async def _listen(self): @@ -92,8 +91,9 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover self.ch = (await self.sub.subscribe(self.channel))[0] return await self.ch.get() except (aioredis.RedisError, OSError): - logger.error('Cannot receive from redis... ' - 'retrying in {} secs'.format(retry_sleep)) + self._get_logger().error('Cannot receive from redis... ' + 'retrying in ' + '{} secs'.format(retry_sleep)) self.sub = None await asyncio.sleep(retry_sleep) retry_sleep *= 2 diff --git a/socketio/base_manager.py b/socketio/base_manager.py index 09bef7c..b4bcf5f 100644 --- a/socketio/base_manager.py +++ b/socketio/base_manager.py @@ -1,7 +1,10 @@ import itertools +import logging import six +default_logger = logging.getLogger('socketio') + class BaseManager(object): """Manage client connections. @@ -13,6 +16,7 @@ class BaseManager(object): subclasses. """ def __init__(self): + self.logger = None self.server = None self.rooms = {} self.callbacks = {} @@ -141,7 +145,7 @@ class BaseManager(object): callback = self.callbacks[sid][namespace][id] except KeyError: # if we get an unknown callback we just ignore it - self.server.logger.warning('Unknown callback received, ignoring.') + self._get_logger().warning('Unknown callback received, ignoring.') else: del self.callbacks[sid][namespace][id] if callback is not None: @@ -157,3 +161,16 @@ class BaseManager(object): id = six.next(self.callbacks[sid][namespace][0]) self.callbacks[sid][namespace][id] = callback return id + + def _get_logger(self): + """Get the appropriate logger + + Prevents uninitialized servers in write-only mode from failing. + """ + + if self.logger: + return self.logger + elif self.server: + return self.server.logger + else: + return default_logger diff --git a/socketio/kombu_manager.py b/socketio/kombu_manager.py index 9906673..4394a6d 100644 --- a/socketio/kombu_manager.py +++ b/socketio/kombu_manager.py @@ -38,12 +38,14 @@ class KombuManager(PubSubManager): # pragma: no cover name = 'kombu' def __init__(self, url='amqp://guest:guest@localhost:5672//', - channel='socketio', write_only=False): + channel='socketio', write_only=False, logger=None): if kombu is None: raise RuntimeError('Kombu package is not installed ' '(Run "pip install kombu" in your ' 'virtualenv).') - super(KombuManager, self).__init__(channel=channel) + super(KombuManager, self).__init__(channel=channel, + write_only=write_only, + logger=logger) self.url = url self.producer = self._producer() @@ -78,7 +80,7 @@ class KombuManager(PubSubManager): # pragma: no cover return self._connection().Producer(exchange=self._exchange()) def __error_callback(self, exception, interval): - self.server.logger.exception('Sleeping {}s'.format(interval)) + self._get_logger().exception('Sleeping {}s'.format(interval)) def _publish(self, data): connection = self._connection() @@ -99,5 +101,5 @@ class KombuManager(PubSubManager): # pragma: no cover message.ack() yield message.payload except connection.connection_errors: - self.server.logger.exception("Connection error " + self._get_logger().exception("Connection error " "while reading from queue") diff --git a/socketio/pubsub_manager.py b/socketio/pubsub_manager.py index afbe276..2905b2c 100644 --- a/socketio/pubsub_manager.py +++ b/socketio/pubsub_manager.py @@ -24,17 +24,18 @@ class PubSubManager(BaseManager): """ name = 'pubsub' - def __init__(self, channel='socketio', write_only=False): + def __init__(self, channel='socketio', write_only=False, logger=None): super(PubSubManager, self).__init__() self.channel = channel self.write_only = write_only self.host_id = uuid.uuid4().hex + self.logger = logger def initialize(self): super(PubSubManager, self).initialize() if not self.write_only: self.thread = self.server.start_background_task(self._thread) - self.server.logger.info(self.name + ' backend initialized.') + self._get_logger().info(self.name + ' backend initialized.') def emit(self, event, data, namespace=None, room=None, skip_sid=None, callback=None, **kwargs): diff --git a/socketio/redis_manager.py b/socketio/redis_manager.py index 9a6f499..69be586 100644 --- a/socketio/redis_manager.py +++ b/socketio/redis_manager.py @@ -37,7 +37,7 @@ class RedisManager(PubSubManager): # pragma: no cover name = 'redis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False): + write_only=False, logger=None): if redis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install redis" in your ' @@ -45,7 +45,8 @@ class RedisManager(PubSubManager): # pragma: no cover self.redis_url = url self._redis_connect() super(RedisManager, self).__init__(channel=channel, - write_only=write_only) + write_only=write_only, + logger=logger) def initialize(self): super(RedisManager, self).initialize() diff --git a/socketio/zmq_manager.py b/socketio/zmq_manager.py index d8995a0..468830b 100644 --- a/socketio/zmq_manager.py +++ b/socketio/zmq_manager.py @@ -50,7 +50,8 @@ class ZmqManager(PubSubManager): # pragma: no cover def __init__(self, url='zmq+tcp://localhost:5555+5556', channel='socketio', - write_only=False): + write_only=False, + logger=None): if zmq is None: raise RuntimeError('zmq package is not installed ' '(Run "pip install pyzmq" in your ' @@ -76,7 +77,8 @@ class ZmqManager(PubSubManager): # pragma: no cover self.sub = sub self.channel = channel super(ZmqManager, self).__init__(channel=channel, - write_only=write_only) + write_only=write_only, + logger=logger) def _publish(self, data): pickled_data = pickle.dumps( diff --git a/tests/test_pubsub_manager.py b/tests/test_pubsub_manager.py index 0430461..295151a 100644 --- a/tests/test_pubsub_manager.py +++ b/tests/test_pubsub_manager.py @@ -1,5 +1,6 @@ import functools import unittest +import logging import six if six.PY3: @@ -39,6 +40,22 @@ class TestBaseManager(unittest.TestCase): self.assertEqual(len(pm.host_id), 32) self.assertEqual(pm.server.start_background_task.call_count, 0) + def test_write_only_default_logger(self): + pm = pubsub_manager.PubSubManager(write_only=True) + pm.initialize() + self.assertEqual(pm.channel, 'socketio') + self.assertEqual(len(pm.host_id), 32) + self.assertEqual(pm._get_logger(), logging.getLogger('socketio')) + + def test_write_only_with_provided_logger(self): + test_logger = logging.getLogger('new_logger') + pm = pubsub_manager.PubSubManager(write_only=True, + logger=test_logger) + pm.initialize() + self.assertEqual(pm.channel, 'socketio') + self.assertEqual(len(pm.host_id), 32) + self.assertEqual(pm._get_logger(), test_logger) + def test_emit(self): self.pm.emit('foo', 'bar') self.pm._publish.assert_called_once_with(