diff --git a/src/socketio/asyncio_aiopika_manager.py b/src/socketio/asyncio_aiopika_manager.py index 905057d..96dcec6 100644 --- a/src/socketio/asyncio_aiopika_manager.py +++ b/src/socketio/asyncio_aiopika_manager.py @@ -94,7 +94,7 @@ class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover async with self.listener_queue.iterator() as queue_iter: async for message in queue_iter: with message.process(): - return pickle.loads(message.body) + yield pickle.loads(message.body) except Exception: self._get_logger().error('Cannot receive from rabbitmq... ' 'retrying in ' diff --git a/src/socketio/asyncio_pubsub_manager.py b/src/socketio/asyncio_pubsub_manager.py index 916c4a6..ac569c8 100644 --- a/src/socketio/asyncio_pubsub_manager.py +++ b/src/socketio/asyncio_pubsub_manager.py @@ -148,35 +148,35 @@ class AsyncPubSubManager(AsyncManager): async def _thread(self): while True: try: - message = await self._listen() + async for message in self._listen(): + data = None + if isinstance(message, dict): + data = message + else: + if isinstance(message, bytes): # pragma: no cover + try: + data = pickle.loads(message) + except: + pass + if data is None: + try: + data = json.loads(message) + except: + pass + if data and 'method' in data: + self._get_logger().info('pubsub message: {}'.format( + data['method'])) + if data['method'] == 'emit': + await self._handle_emit(data) + elif data['method'] == 'callback': + await self._handle_callback(data) + elif data['method'] == 'disconnect': + await self._handle_disconnect(data) + elif data['method'] == 'close_room': + await self._handle_close_room(data) except asyncio.CancelledError: # pragma: no cover break except: import traceback traceback.print_exc() break - data = None - if isinstance(message, dict): - data = message - else: - if isinstance(message, bytes): # pragma: no cover - try: - data = pickle.loads(message) - except: - pass - if data is None: - try: - data = json.loads(message) - except: - pass - if data and 'method' in data: - self._get_logger().info('pubsub message: {}'.format( - data['method'])) - if data['method'] == 'emit': - await self._handle_emit(data) - elif data['method'] == 'callback': - await self._handle_callback(data) - elif data['method'] == 'disconnect': - await self._handle_disconnect(data) - elif data['method'] == 'close_room': - await self._handle_close_room(data) diff --git a/src/socketio/asyncio_redis_manager.py b/src/socketio/asyncio_redis_manager.py index 9762d3e..7f96e8d 100644 --- a/src/socketio/asyncio_redis_manager.py +++ b/src/socketio/asyncio_redis_manager.py @@ -10,21 +10,6 @@ except ImportError: from .asyncio_pubsub_manager import AsyncPubSubManager -def _parse_redis_url(url): - p = urlparse(url) - if p.scheme not in {'redis', 'rediss'}: - raise ValueError('Invalid redis url') - ssl = p.scheme == 'rediss' - host = p.hostname or 'localhost' - port = p.port or 6379 - password = p.password - if p.path: - db = int(p.path[1:]) - else: - db = 0 - return host, port, password, db, ssl - - class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover """Redis based client manager for asyncio servers. @@ -51,58 +36,41 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover name = 'aioredis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False, logger=None): + write_only=False, logger=None, redis_options={}): if aioredis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install aioredis" in your ' 'virtualenv).') - ( - self.host, self.port, self.password, self.db, self.ssl - ) = _parse_redis_url(url) - self.pub = None - self.sub = None + self.redis = aioredis.from_url(url, **redis_options) + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) super().__init__(channel=channel, write_only=write_only, logger=logger) async def _publish(self, data): retry = True while True: try: - if self.pub is None: - self.pub = await aioredis.create_redis( - (self.host, self.port), db=self.db, - password=self.password, ssl=self.ssl - ) - return await self.pub.publish(self.channel, - pickle.dumps(data)) - except (aioredis.RedisError, OSError): + return await self.redis.publish(self.channel, pickle.dumps(data)) + except redis.exceptions.RedisError: if retry: - self._get_logger().error('Cannot publish to redis... ' - 'retrying') - self.pub = None + self._get_logger().error('Cannot publish to redis... retrying') retry = False else: - self._get_logger().error('Cannot publish to redis... ' - 'giving up') + self._get_logger().error('Cannot publish to redis... giving up') break async def _listen(self): retry_sleep = 1 while True: try: - if self.sub is None: - self.sub = await aioredis.create_redis( - (self.host, self.port), db=self.db, - password=self.password, ssl=self.ssl - ) - self.ch = (await self.sub.subscribe(self.channel))[0] + await self.pubsub.subscribe(self.channel) retry_sleep = 1 - return await self.ch.get() - except (aioredis.RedisError, OSError): + async for message in self.pubsub.listen(): + yield message['data'] + except aioredis.exceptions.RedisError: self._get_logger().error('Cannot receive from redis... ' - 'retrying in ' - '{} secs'.format(retry_sleep)) - self.sub = None + 'retrying in {} secs'.format(retry_sleep)) await asyncio.sleep(retry_sleep) retry_sleep *= 2 if retry_sleep > 60: retry_sleep = 60 + await self.pubsub.unsubscribe(self.channel) diff --git a/tests/asyncio/test_asyncio_pubsub_manager.py b/tests/asyncio/test_asyncio_pubsub_manager.py index 5cefec8..8bce2a1 100644 --- a/tests/asyncio/test_asyncio_pubsub_manager.py +++ b/tests/asyncio/test_asyncio_pubsub_manager.py @@ -429,7 +429,9 @@ class TestAsyncPubSubManager(unittest.TestCase): yield 'bad json' yield b'bad pickled' - self.pm._listen = AsyncMock(side_effect=list(messages())) + m = mock.MagicMock() + m.__aiter__.return_value = messages() + self.pm._listen = mock.MagicMock(side_effect=[m]) try: _run(self.pm._thread()) except StopIteration: diff --git a/tests/asyncio/test_asyncio_redis_manager.py b/tests/asyncio/test_asyncio_redis_manager.py index a8cf7d8..8042679 100644 --- a/tests/asyncio/test_asyncio_redis_manager.py +++ b/tests/asyncio/test_asyncio_redis_manager.py @@ -7,6 +7,7 @@ from socketio import asyncio_redis_manager @unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') +@unittest.skip("Deprecated") class TestAsyncRedisManager(unittest.TestCase): def test_default_url(self): assert asyncio_redis_manager._parse_redis_url('redis://') == (