diff --git a/.gitignore b/.gitignore index 01d8b9a..5899acd 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ htmlcov *.swp node_modules +.venv/ +.python-version diff --git a/src/socketio/async_redis_manager.py b/src/socketio/async_redis_manager.py index b8ac4a0..99f955f 100644 --- a/src/socketio/async_redis_manager.py +++ b/src/socketio/async_redis_manager.py @@ -1,4 +1,6 @@ import asyncio +import contextlib +import random from urllib.parse import urlparse try: @@ -51,13 +53,21 @@ class AsyncRedisManager(AsyncPubSubManager): :param redis_options: additional keyword arguments to be passed to ``Redis.from_url()`` or ``Sentinel()``. """ - name = 'aioredis' - def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False, logger=None, redis_options=None): - if aioredis and \ - not hasattr(aioredis.Redis, 'from_url'): # pragma: no cover - raise RuntimeError('Version 2 of aioredis package is required.') + name = "aioredis" + + def __init__( + self, + url="redis://localhost:6379/0", + channel="socketio", + write_only=False, + logger=None, + redis_options=None, + ): + if aioredis and not hasattr( + aioredis.Redis, "from_url" + ): # pragma: no cover + raise RuntimeError("Version 2 of aioredis package is required.") super().__init__(channel=channel, write_only=write_only, logger=logger) self.redis_url = url self.redis_options = redis_options or {} @@ -65,46 +75,63 @@ class AsyncRedisManager(AsyncPubSubManager): def _get_redis_module_and_error(self): parsed_url = urlparse(self.redis_url) - scheme = parsed_url.scheme.split('+', 1)[0].lower() - if scheme in ['redis', 'rediss']: + scheme = parsed_url.scheme.split("+", 1)[0].lower() + if scheme in ["redis", "rediss"]: if aioredis is None or RedisError is None: - raise RuntimeError('Redis package is not installed ' - '(Run "pip install redis" ' - 'in your virtualenv).') + raise RuntimeError( + "Redis package is not installed " + '(Run "pip install redis" ' + "in your virtualenv)." + ) return aioredis, RedisError - if scheme in ['valkey', 'valkeys']: + if scheme in ["valkey", "valkeys"]: if aiovalkey is None or ValkeyError is None: - raise RuntimeError('Valkey package is not installed ' - '(Run "pip install valkey" ' - 'in your virtualenv).') + raise RuntimeError( + "Valkey package is not installed " + '(Run "pip install valkey" ' + "in your virtualenv)." + ) return aiovalkey, ValkeyError - if scheme == 'unix': - if aioredis is None or RedisError is None: - if aiovalkey is None or ValkeyError is None: - raise RuntimeError('Redis package is not installed ' - '(Run "pip install redis" ' - 'or "pip install valkey" ' - 'in your virtualenv).') - else: - return aiovalkey, ValkeyError - else: + if scheme == "unix": + if aioredis is not None and RedisError is not None: return aioredis, RedisError - error_msg = f'Unsupported Redis URL scheme: {scheme}' - raise ValueError(error_msg) + if aiovalkey is not None and ValkeyError is not None: + return aiovalkey, ValkeyError + raise RuntimeError('Install "redis" or "valkey" package.') + raise ValueError(f"Unsupported Redis URL scheme: {scheme}") def _redis_connect(self): module, _ = self._get_redis_module_and_error() parsed_url = urlparse(self.redis_url) + + # Backend-aware pubsub socket defaults. Caller can override. + is_valkey = module.__name__.startswith("valkey.") + pubsub_defaults = { + "decode_responses": False, + "socket_keepalive": True, + "retry_on_timeout": False, + } + if is_valkey: + pubsub_defaults.update( + { + "socket_timeout": None, # block indefinitely + "socket_connect_timeout": 3, # fail fast on bad host + "health_check_interval": 0, # no PINGs on pubsub socket + } + ) + + kwargs = {**pubsub_defaults, **(self.redis_options or {})} + if parsed_url.scheme in {"redis+sentinel", "valkey+sentinel"}: - sentinels, service_name, connection_kwargs = \ + sentinels, service_name, connection_kwargs = ( parse_redis_sentinel_url(self.redis_url) - kwargs = self.redis_options - kwargs.update(connection_kwargs) - sentinel = module.sentinel.Sentinel(sentinels, **kwargs) + ) + connection_kwargs.update(kwargs) + sentinel = module.sentinel.Sentinel(sentinels, **connection_kwargs) self.redis = sentinel.master_for(service_name or self.channel) else: - self.redis = module.Redis.from_url(self.redis_url, - **self.redis_options) + self.redis = module.Redis.from_url(self.redis_url, **kwargs) + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) async def _publish(self, data): # pragma: no cover @@ -132,33 +159,70 @@ class AsyncRedisManager(AsyncPubSubManager): break async def _redis_listen_with_retries(self): # pragma: no cover - retry_sleep = 1 - connect = False - _, error = self._get_redis_module_and_error() + """ + Stream pub/sub messages forever; auto-reconnect on transient errors. + - Works with both Redis and Valkey (we detect the right error class). + - Backoff: 1s -> 2 -> 4 ... capped at 60s, with ±20% jitter. + - Any successfully received message resets the backoff to 1s. + - On shutdown (CancelledError) we try to unsubscribe cleanly. + """ + backoff = 1.0 + max_backoff = 60.0 + connect_needed = True + _, BackendError = self._get_redis_module_and_error() + backend_name = getattr(self, "name", "redis") + while True: try: - if connect: + if connect_needed: self._redis_connect() await self.pubsub.subscribe(self.channel) - retry_sleep = 1 + connect_needed = False + async for message in self.pubsub.listen(): + backoff = 1.0 yield message - except error as exc: - self._get_logger().error('Cannot receive from redis... ' - 'retrying in ' - f'{retry_sleep} secs', - extra={"redis_exception": str(exc)}) - connect = True - await asyncio.sleep(retry_sleep) - retry_sleep *= 2 - if retry_sleep > 60: - retry_sleep = 60 + + except asyncio.CancelledError: + with contextlib.suppress(Exception): + await self.pubsub.unsubscribe(self.channel) + raise + + except (BackendError, OSError, TimeoutError) as exc: + self._get_logger().error( + "%s pub/sub listen error; reconnecting in %.1fs", + backend_name, + backoff, + extra={"backend_exception": str(exc)}, + ) + connect_needed = True + + jitter = backoff * (random.random() * 0.4 - 0.2) + await asyncio.sleep(max(0.0, backoff + jitter)) + backoff = min(backoff * 2, max_backoff) + + @staticmethod + def _channel_matches(expected: bytes, msg_channel) -> bool: + if isinstance(msg_channel, bytes): + return msg_channel == expected + if isinstance(msg_channel, str): + try: + return msg_channel.encode("utf-8") == expected + except Exception: + return False + return False async def _listen(self): # pragma: no cover - channel = self.channel.encode('utf-8') + """Continuously listen on the pub/sub channel and yield messages.""" + expected = self.channel.encode("utf-8") await self.pubsub.subscribe(self.channel) + async for message in self._redis_listen_with_retries(): - if message['channel'] == channel and \ - message['type'] == 'message' and 'data' in message: - yield message['data'] + if ( + message.get("type") == "message" + and "data" in message + and self._channel_matches(expected, message.get("channel")) + ): + yield message["data"] + await self.pubsub.unsubscribe(self.channel) diff --git a/tests/async/test_redis_manager.py b/tests/async/test_redis_manager.py index 01c0c37..6d41a14 100644 --- a/tests/async/test_redis_manager.py +++ b/tests/async/test_redis_manager.py @@ -105,3 +105,38 @@ class TestAsyncRedisManager: assert isinstance(c.redis, valkey.asyncio.Valkey) async_redis_manager.aioredis = saved_redis + + def test_channel_matches(self): + expected = b"socketio" + + # bytes: matching and non-matching + assert AsyncRedisManager._channel_matches( + expected, b"socketio" + ) is True + assert AsyncRedisManager._channel_matches( + expected, b"other" + ) is False + + # str: matching and non-matching + assert AsyncRedisManager._channel_matches( + expected, "socketio" + ) is True + assert AsyncRedisManager._channel_matches( + expected, "other" + ) is False + + # str: encoding raises + class BadStr(str): + def encode(self, *_, **__): + raise UnicodeEncodeError( + "utf-8", "x", 0, 1, "boom" + ) + + assert AsyncRedisManager._channel_matches( + expected, BadStr("foo") + ) is False + + # non-string/non-bytes (int) + assert AsyncRedisManager._channel_matches( + expected, 123 + ) is False