diff --git a/src/socketio/async_redis_manager.py b/src/socketio/async_redis_manager.py index b8ac4a0..925c4c5 100644 --- a/src/socketio/async_redis_manager.py +++ b/src/socketio/async_redis_manager.py @@ -1,4 +1,5 @@ import asyncio +import contextlib from urllib.parse import urlparse try: @@ -133,32 +134,42 @@ class AsyncRedisManager(AsyncPubSubManager): async def _redis_listen_with_retries(self): # pragma: no cover retry_sleep = 1 - connect = False - _, error = self._get_redis_module_and_error() - while True: - try: - if connect: - self._redis_connect() - await self.pubsub.subscribe(self.channel) - retry_sleep = 1 - async for message in self.pubsub.listen(): - yield message - except error as exc: - self._get_logger().error('Cannot receive from redis... ' - 'retrying in ' - f'{retry_sleep} secs', - extra={"redis_exception": str(exc)}) - connect = True - await asyncio.sleep(retry_sleep) - retry_sleep *= 2 - if retry_sleep > 60: - retry_sleep = 60 + connect = True + _, BackendError = self._get_redis_module_and_error() + try: + while True: + try: + if connect: + self._redis_connect() + await self.pubsub.subscribe(self.channel) + retry_sleep = 1 + connect = False + async for message in self.pubsub.listen(): + yield message + except (BackendError, OSError, TimeoutError) as exc: + self._get_logger().error( + 'Cannot receive from redis... ' + 'retrying in ' + f'{retry_sleep} secs', + extra={"redis_exception": str(exc)}) + connect = True + await asyncio.sleep(retry_sleep) + retry_sleep *= 2 + if retry_sleep > 60: + retry_sleep = 60 + except asyncio.CancelledError: + raise + finally: + with contextlib.suppress(Exception): + await self.pubsub.unsubscribe(self.channel) async def _listen(self): # pragma: no cover channel = self.channel.encode('utf-8') - await self.pubsub.subscribe(self.channel) - async for message in self._redis_listen_with_retries(): - if message['channel'] == channel and \ - message['type'] == 'message' and 'data' in message: - yield message['data'] - await self.pubsub.unsubscribe(self.channel) + try: + async for message in self._redis_listen_with_retries(): + if message['channel'] == channel and \ + message['type'] == 'message' and 'data' in message: + yield message['data'] + finally: + with contextlib.suppress(Exception): + await self.pubsub.unsubscribe(self.channel) diff --git a/tests/async/test_redis_manager.py b/tests/async/test_redis_manager.py index 01c0c37..4c03576 100644 --- a/tests/async/test_redis_manager.py +++ b/tests/async/test_redis_manager.py @@ -1,6 +1,8 @@ +import asyncio import pytest import redis import valkey +import types from socketio import async_redis_manager from socketio.async_redis_manager import AsyncRedisManager @@ -105,3 +107,108 @@ class TestAsyncRedisManager: assert isinstance(c.redis, valkey.asyncio.Valkey) async_redis_manager.aioredis = saved_redis + + +class _FakePubSub: + def __init__(self, chan_bytes, script): + self._chan = chan_bytes + self._script = list(script) + self._unsubscribed = False + + async def subscribe(self, channel): + return True + + async def unsubscribe(self, channel): + self._unsubscribed = True + return True + + async def listen(self): + while self._script: + step = self._script.pop(0) + if step == "timeout": + raise TimeoutError("simulated timeout") + if step == "msg": + yield { + "type": "message", + "channel": self._chan, + "data": b"ok", + } + while True: + await asyncio.sleep(3600) + + +class _FakeRedis: + def __init__(self, **opts): + self._opts = opts + self._scripts = [] + + @classmethod + def from_url(cls, url, **kwargs): + obj = cls(**kwargs) + return obj + + def pubsub(self, ignore_subscribe_messages=True): + script = self._scripts.pop(0) if self._scripts else ["msg"] + return _FakePubSub(b"socketio", script) + + +class _FakeSentinel: + class Sentinel: + def __init__(self, *a, **kw): + pass + + def master_for(self, *a, **kw): + return _FakeRedis() + + +def _fake_valkey_module(): + mod = types.SimpleNamespace() + mod.__name__ = "valkey.asyncio" + mod.Redis = _FakeRedis + mod.sentinel = _FakeSentinel + return mod + + +class _TestManager(AsyncRedisManager): + """AsyncRedisManager that uses our fake 'valkey' module and scripts.""" + + def __init__(self, scripts, **kw): + # scripts is a list of lists, e.g. [["timeout"], ["msg"]] + self._scripts = scripts + super().__init__(url="valkey://localhost/0", **kw) + + def _get_redis_module_and_error(self): + fake = _fake_valkey_module() + return fake, RuntimeError + + def _redis_connect(self): + module, _ = self._get_redis_module_and_error() + self.redis = module.Redis.from_url( + self.redis_url, **(self.redis_options or {}) + ) + self.redis._scripts = self._scripts + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) + + +@pytest.mark.asyncio +async def test_listen_reconnects_after_timeout_and_yields(): + """First TimeoutError -> reconnect + resubscribe -> yield next message.""" + mgr = _TestManager(scripts=[["timeout"], ["msg"]]) + agen = mgr._listen() + + got = await asyncio.wait_for(agen.__anext__(), timeout=2.5) + assert got == b"ok" + + await agen.aclose() + assert mgr.pubsub._unsubscribed is True + + +@pytest.mark.asyncio +async def test_listen_aclose_unsubscribes(): + """Closing the async generator must unsubscribe the pub/sub.""" + mgr = _TestManager(scripts=[["msg"]]) + agen = mgr._listen() + + _ = await asyncio.wait_for(agen.__anext__(), timeout=1.0) + await agen.aclose() + assert mgr.pubsub._unsubscribed is True