@ -1,6 +1,8 @@
import asyncio
import pytest
import redis
import valkey
import types
from socketio import async_redis_manager
from socketio . async_redis_manager import AsyncRedisManager
@ -105,3 +107,108 @@ class TestAsyncRedisManager:
assert isinstance ( c . redis , valkey . asyncio . Valkey )
async_redis_manager . aioredis = saved_redis
class _FakePubSub :
def __init__ ( self , chan_bytes , script ) :
self . _chan = chan_bytes
self . _script = list ( script )
self . _unsubscribed = False
async def subscribe ( self , channel ) :
return True
async def unsubscribe ( self , channel ) :
self . _unsubscribed = True
return True
async def listen ( self ) :
while self . _script :
step = self . _script . pop ( 0 )
if step == " timeout " :
raise TimeoutError ( " simulated timeout " )
if step == " msg " :
yield {
" type " : " message " ,
" channel " : self . _chan ,
" data " : b " ok " ,
}
while True :
await asyncio . sleep ( 3600 )
class _FakeRedis :
def __init__ ( self , * * opts ) :
self . _opts = opts
self . _scripts = [ ]
@classmethod
def from_url ( cls , url , * * kwargs ) :
obj = cls ( * * kwargs )
return obj
def pubsub ( self , ignore_subscribe_messages = True ) :
script = self . _scripts . pop ( 0 ) if self . _scripts else [ " msg " ]
return _FakePubSub ( b " socketio " , script )
class _FakeSentinel :
class Sentinel :
def __init__ ( self , * a , * * kw ) :
pass
def master_for ( self , * a , * * kw ) :
return _FakeRedis ( )
def _fake_valkey_module ( ) :
mod = types . SimpleNamespace ( )
mod . __name__ = " valkey.asyncio "
mod . Redis = _FakeRedis
mod . sentinel = _FakeSentinel
return mod
class _TestManager ( AsyncRedisManager ) :
""" AsyncRedisManager that uses our fake ' valkey ' module and scripts. """
def __init__ ( self , scripts , * * kw ) :
# scripts is a list of lists, e.g. [["timeout"], ["msg"]]
self . _scripts = scripts
super ( ) . __init__ ( url = " valkey://localhost/0 " , * * kw )
def _get_redis_module_and_error ( self ) :
fake = _fake_valkey_module ( )
return fake , RuntimeError
def _redis_connect ( self ) :
module , _ = self . _get_redis_module_and_error ( )
self . redis = module . Redis . from_url (
self . redis_url , * * ( self . redis_options or { } )
)
self . redis . _scripts = self . _scripts
self . pubsub = self . redis . pubsub ( ignore_subscribe_messages = True )
@pytest . mark . asyncio
async def test_listen_reconnects_after_timeout_and_yields ( ) :
""" First TimeoutError -> reconnect + resubscribe -> yield next message. """
mgr = _TestManager ( scripts = [ [ " timeout " ] , [ " msg " ] ] )
agen = mgr . _listen ( )
got = await asyncio . wait_for ( agen . __anext__ ( ) , timeout = 2.5 )
assert got == b " ok "
await agen . aclose ( )
assert mgr . pubsub . _unsubscribed is True
@pytest . mark . asyncio
async def test_listen_aclose_unsubscribes ( ) :
""" Closing the async generator must unsubscribe the pub/sub. """
mgr = _TestManager ( scripts = [ [ " msg " ] ] )
agen = mgr . _listen ( )
_ = await asyncio . wait_for ( agen . __anext__ ( ) , timeout = 1.0 )
await agen . aclose ( )
assert mgr . pubsub . _unsubscribed is True