From 0b25ff42b8927ac881be7c8ebe1785819bc4c35e Mon Sep 17 00:00:00 2001 From: Dylan Anthony <43723790+dbanty@users.noreply.github.com> Date: Sun, 21 Jul 2019 06:47:03 -0400 Subject: [PATCH] Added rediss:// URL scheme to AsyncRedisManager (#319) * Added rediss:// URL scheme to AsyncRedisManager * Obeyed flake8 --- socketio/asyncio_redis_manager.py | 18 ++++++++++++------ tests/asyncio/test_asyncio_redis_manager.py | 20 +++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/socketio/asyncio_redis_manager.py b/socketio/asyncio_redis_manager.py index 2265937..21499c2 100644 --- a/socketio/asyncio_redis_manager.py +++ b/socketio/asyncio_redis_manager.py @@ -12,8 +12,9 @@ from .asyncio_pubsub_manager import AsyncPubSubManager def _parse_redis_url(url): p = urlparse(url) - if p.scheme != 'redis': + if p.scheme not in {'redis', 'rediss'}: raise ValueError('Invalid redis url') + ssl = p.scheme == 'rediss' host = p.hostname or 'localhost' port = p.port or 6379 password = p.password @@ -21,7 +22,7 @@ def _parse_redis_url(url): db = int(p.path[1:]) else: db = 0 - return host, port, password, db + return host, port, password, db, ssl class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover @@ -39,7 +40,8 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover 'redis://hostname:port/0')) :param url: The connection URL for the Redis server. For a default Redis - store running on the same host, use ``redis://``. + store running on the same host, use ``redis://``. To use an + SSL connection, use ``rediss://``. :param channel: The channel name on which the server sends and receives notifications. Must be the same in all the servers. :param write_only: If set ot ``True``, only initialize to emit events. The @@ -54,7 +56,9 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover raise RuntimeError('Redis package is not installed ' '(Run "pip install aioredis" in your ' 'virtualenv).') - self.host, self.port, self.password, self.db = _parse_redis_url(url) + ( + self.host, self.port, self.password, self.db, self.ssl + ) = _parse_redis_url(url) self.pub = None self.sub = None super().__init__(channel=channel, write_only=write_only, logger=logger) @@ -66,7 +70,8 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover if self.pub is None: self.pub = await aioredis.create_redis( (self.host, self.port), db=self.db, - password=self.password) + password=self.password, ssl=self.ssl + ) return await self.pub.publish(self.channel, pickle.dumps(data)) except (aioredis.RedisError, OSError): @@ -87,7 +92,8 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover if self.sub is None: self.sub = await aioredis.create_redis( (self.host, self.port), db=self.db, - password=self.password) + password=self.password, ssl=self.ssl + ) self.ch = (await self.sub.subscribe(self.channel))[0] return await self.ch.get() except (aioredis.RedisError, OSError): diff --git a/tests/asyncio/test_asyncio_redis_manager.py b/tests/asyncio/test_asyncio_redis_manager.py index 02c12d6..cbcaf6a 100644 --- a/tests/asyncio/test_asyncio_redis_manager.py +++ b/tests/asyncio/test_asyncio_redis_manager.py @@ -8,37 +8,37 @@ from socketio import asyncio_redis_manager class TestAsyncRedisManager(unittest.TestCase): def test_default_url(self): self.assertEqual(asyncio_redis_manager._parse_redis_url('redis://'), - ('localhost', 6379, None, 0)) + ('localhost', 6379, None, 0, False)) def test_only_host_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://redis.host'), - ('redis.host', 6379, None, 0)) + ('redis.host', 6379, None, 0, False)) def test_no_db_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://redis.host:123/1'), - ('redis.host', 123, None, 1)) + ('redis.host', 123, None, 1, False)) def test_no_port_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://redis.host/1'), - ('redis.host', 6379, None, 1)) + ('redis.host', 6379, None, 1, False)) def test_password(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://:pw@redis.host/1'), - ('redis.host', 6379, 'pw', 1)) + ('redis.host', 6379, 'pw', 1, False)) def test_no_host_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://:123/1'), - ('localhost', 123, None, 1)) + ('localhost', 123, None, 1, False)) def test_no_host_password_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://:pw@:123/1'), - ('localhost', 123, 'pw', 1)) + ('localhost', 123, 'pw', 1, False)) def test_bad_port_url(self): self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, @@ -51,3 +51,9 @@ class TestAsyncRedisManager(unittest.TestCase): def test_bad_scheme_url(self): self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, 'http://redis.host:123/1') + + def test_ssl_scheme(self): + self.assertEqual( + asyncio_redis_manager._parse_redis_url('rediss://'), + ('localhost', 6379, None, 0, True) + )