Browse Source

async_redis_manager: add Valkey support + robust listen loop; tests for valkey urls and channel matching

pull/1531/head
youssef khaya 7 months ago
parent
commit
81c3988252
  1. 2
      .gitignore
  2. 168
      src/socketio/async_redis_manager.py
  3. 35
      tests/async/test_redis_manager.py

2
.gitignore

@ -45,3 +45,5 @@ htmlcov
*.swp *.swp
node_modules node_modules
.venv/
.python-version

168
src/socketio/async_redis_manager.py

@ -1,4 +1,6 @@
import asyncio import asyncio
import contextlib
import random
from urllib.parse import urlparse from urllib.parse import urlparse
try: try:
@ -51,13 +53,21 @@ class AsyncRedisManager(AsyncPubSubManager):
:param redis_options: additional keyword arguments to be passed to :param redis_options: additional keyword arguments to be passed to
``Redis.from_url()`` or ``Sentinel()``. ``Redis.from_url()`` or ``Sentinel()``.
""" """
name = 'aioredis'
def __init__(self, url='redis://localhost:6379/0', channel='socketio', name = "aioredis"
write_only=False, logger=None, redis_options=None):
if aioredis and \ def __init__(
not hasattr(aioredis.Redis, 'from_url'): # pragma: no cover self,
raise RuntimeError('Version 2 of aioredis package is required.') url="redis://localhost:6379/0",
channel="socketio",
write_only=False,
logger=None,
redis_options=None,
):
if aioredis and not hasattr(
aioredis.Redis, "from_url"
): # pragma: no cover
raise RuntimeError("Version 2 of aioredis package is required.")
super().__init__(channel=channel, write_only=write_only, logger=logger) super().__init__(channel=channel, write_only=write_only, logger=logger)
self.redis_url = url self.redis_url = url
self.redis_options = redis_options or {} self.redis_options = redis_options or {}
@ -65,46 +75,63 @@ class AsyncRedisManager(AsyncPubSubManager):
def _get_redis_module_and_error(self): def _get_redis_module_and_error(self):
parsed_url = urlparse(self.redis_url) parsed_url = urlparse(self.redis_url)
scheme = parsed_url.scheme.split('+', 1)[0].lower() scheme = parsed_url.scheme.split("+", 1)[0].lower()
if scheme in ['redis', 'rediss']: if scheme in ["redis", "rediss"]:
if aioredis is None or RedisError is None: if aioredis is None or RedisError is None:
raise RuntimeError('Redis package is not installed ' raise RuntimeError(
'(Run "pip install redis" ' "Redis package is not installed "
'in your virtualenv).') '(Run "pip install redis" '
"in your virtualenv)."
)
return aioredis, RedisError return aioredis, RedisError
if scheme in ['valkey', 'valkeys']: if scheme in ["valkey", "valkeys"]:
if aiovalkey is None or ValkeyError is None: if aiovalkey is None or ValkeyError is None:
raise RuntimeError('Valkey package is not installed ' raise RuntimeError(
'(Run "pip install valkey" ' "Valkey package is not installed "
'in your virtualenv).') '(Run "pip install valkey" '
"in your virtualenv)."
)
return aiovalkey, ValkeyError return aiovalkey, ValkeyError
if scheme == 'unix': if scheme == "unix":
if aioredis is None or RedisError is None: if aioredis is not None and RedisError is not None:
if aiovalkey is None or ValkeyError is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install redis" '
'or "pip install valkey" '
'in your virtualenv).')
else:
return aiovalkey, ValkeyError
else:
return aioredis, RedisError return aioredis, RedisError
error_msg = f'Unsupported Redis URL scheme: {scheme}' if aiovalkey is not None and ValkeyError is not None:
raise ValueError(error_msg) return aiovalkey, ValkeyError
raise RuntimeError('Install "redis" or "valkey" package.')
raise ValueError(f"Unsupported Redis URL scheme: {scheme}")
def _redis_connect(self): def _redis_connect(self):
module, _ = self._get_redis_module_and_error() module, _ = self._get_redis_module_and_error()
parsed_url = urlparse(self.redis_url) parsed_url = urlparse(self.redis_url)
# Backend-aware pubsub socket defaults. Caller can override.
is_valkey = module.__name__.startswith("valkey.")
pubsub_defaults = {
"decode_responses": False,
"socket_keepalive": True,
"retry_on_timeout": False,
}
if is_valkey:
pubsub_defaults.update(
{
"socket_timeout": None, # block indefinitely
"socket_connect_timeout": 3, # fail fast on bad host
"health_check_interval": 0, # no PINGs on pubsub socket
}
)
kwargs = {**pubsub_defaults, **(self.redis_options or {})}
if parsed_url.scheme in {"redis+sentinel", "valkey+sentinel"}: if parsed_url.scheme in {"redis+sentinel", "valkey+sentinel"}:
sentinels, service_name, connection_kwargs = \ sentinels, service_name, connection_kwargs = (
parse_redis_sentinel_url(self.redis_url) parse_redis_sentinel_url(self.redis_url)
kwargs = self.redis_options )
kwargs.update(connection_kwargs) connection_kwargs.update(kwargs)
sentinel = module.sentinel.Sentinel(sentinels, **kwargs) sentinel = module.sentinel.Sentinel(sentinels, **connection_kwargs)
self.redis = sentinel.master_for(service_name or self.channel) self.redis = sentinel.master_for(service_name or self.channel)
else: else:
self.redis = module.Redis.from_url(self.redis_url, self.redis = module.Redis.from_url(self.redis_url, **kwargs)
**self.redis_options)
self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
async def _publish(self, data): # pragma: no cover async def _publish(self, data): # pragma: no cover
@ -132,33 +159,70 @@ class AsyncRedisManager(AsyncPubSubManager):
break break
async def _redis_listen_with_retries(self): # pragma: no cover async def _redis_listen_with_retries(self): # pragma: no cover
retry_sleep = 1 """
connect = False Stream pub/sub messages forever; auto-reconnect on transient errors.
_, error = self._get_redis_module_and_error() - Works with both Redis and Valkey (we detect the right error class).
- Backoff: 1s -> 2 -> 4 ... capped at 60s, with ±20% jitter.
- Any successfully received message resets the backoff to 1s.
- On shutdown (CancelledError) we try to unsubscribe cleanly.
"""
backoff = 1.0
max_backoff = 60.0
connect_needed = True
_, BackendError = self._get_redis_module_and_error()
backend_name = getattr(self, "name", "redis")
while True: while True:
try: try:
if connect: if connect_needed:
self._redis_connect() self._redis_connect()
await self.pubsub.subscribe(self.channel) await self.pubsub.subscribe(self.channel)
retry_sleep = 1 connect_needed = False
async for message in self.pubsub.listen(): async for message in self.pubsub.listen():
backoff = 1.0
yield message yield message
except error as exc:
self._get_logger().error('Cannot receive from redis... ' except asyncio.CancelledError:
'retrying in ' with contextlib.suppress(Exception):
f'{retry_sleep} secs', await self.pubsub.unsubscribe(self.channel)
extra={"redis_exception": str(exc)}) raise
connect = True
await asyncio.sleep(retry_sleep) except (BackendError, OSError, TimeoutError) as exc:
retry_sleep *= 2 self._get_logger().error(
if retry_sleep > 60: "%s pub/sub listen error; reconnecting in %.1fs",
retry_sleep = 60 backend_name,
backoff,
extra={"backend_exception": str(exc)},
)
connect_needed = True
jitter = backoff * (random.random() * 0.4 - 0.2)
await asyncio.sleep(max(0.0, backoff + jitter))
backoff = min(backoff * 2, max_backoff)
@staticmethod
def _channel_matches(expected: bytes, msg_channel) -> bool:
if isinstance(msg_channel, bytes):
return msg_channel == expected
if isinstance(msg_channel, str):
try:
return msg_channel.encode("utf-8") == expected
except Exception:
return False
return False
async def _listen(self): # pragma: no cover async def _listen(self): # pragma: no cover
channel = self.channel.encode('utf-8') """Continuously listen on the pub/sub channel and yield messages."""
expected = self.channel.encode("utf-8")
await self.pubsub.subscribe(self.channel) await self.pubsub.subscribe(self.channel)
async for message in self._redis_listen_with_retries(): async for message in self._redis_listen_with_retries():
if message['channel'] == channel and \ if (
message['type'] == 'message' and 'data' in message: message.get("type") == "message"
yield message['data'] and "data" in message
and self._channel_matches(expected, message.get("channel"))
):
yield message["data"]
await self.pubsub.unsubscribe(self.channel) await self.pubsub.unsubscribe(self.channel)

35
tests/async/test_redis_manager.py

@ -105,3 +105,38 @@ class TestAsyncRedisManager:
assert isinstance(c.redis, valkey.asyncio.Valkey) assert isinstance(c.redis, valkey.asyncio.Valkey)
async_redis_manager.aioredis = saved_redis async_redis_manager.aioredis = saved_redis
def test_channel_matches(self):
expected = b"socketio"
# bytes: matching and non-matching
assert AsyncRedisManager._channel_matches(
expected, b"socketio"
) is True
assert AsyncRedisManager._channel_matches(
expected, b"other"
) is False
# str: matching and non-matching
assert AsyncRedisManager._channel_matches(
expected, "socketio"
) is True
assert AsyncRedisManager._channel_matches(
expected, "other"
) is False
# str: encoding raises
class BadStr(str):
def encode(self, *_, **__):
raise UnicodeEncodeError(
"utf-8", "x", 0, 1, "boom"
)
assert AsyncRedisManager._channel_matches(
expected, BadStr("foo")
) is False
# non-string/non-bytes (int)
assert AsyncRedisManager._channel_matches(
expected, 123
) is False

Loading…
Cancel
Save