From af6818453e3312ef623eed61326d751abf1306d6 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 31 Mar 2023 15:47:07 +0100 Subject: [PATCH] Make async client manager connect a coroutine --- src/socketio/async_manager.py | 7 ++ src/socketio/async_server.py | 2 +- tests/async/test_manager.py | 108 ++++++++++++++--------------- tests/async/test_pubsub_manager.py | 6 +- tests/async/test_server.py | 24 +++---- 5 files changed, 77 insertions(+), 70 deletions(-) diff --git a/src/socketio/async_manager.py b/src/socketio/async_manager.py index 6646376..dcf79cf 100644 --- a/src/socketio/async_manager.py +++ b/src/socketio/async_manager.py @@ -62,6 +62,13 @@ class AsyncManager(BaseManager): return await asyncio.wait(tasks) + async def connect(self, eio_sid, namespace): + """Register a client connection to a namespace. + + Note: this method is a coroutine. + """ + return super().connect(eio_sid, namespace) + async def disconnect(self, sid, namespace, **kwargs): """Disconnect a client. diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index f3bb8f8..89b7c50 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -486,7 +486,7 @@ class AsyncServer(base_server.BaseServer): sid = None if namespace in self.handlers or namespace in self.namespace_handlers \ or self.namespaces == '*' or namespace in self.namespaces: - sid = self.manager.connect(eio_sid, namespace) + sid = await self.manager.connect(eio_sid, namespace) if sid is None: await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', diff --git a/tests/async/test_manager.py b/tests/async/test_manager.py index 7cfb46c..306734d 100644 --- a/tests/async/test_manager.py +++ b/tests/async/test_manager.py @@ -27,7 +27,7 @@ class TestAsyncManager(unittest.TestCase): self.bm.initialize() def test_connect(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) assert None in self.bm.rooms['/foo'] assert sid in self.bm.rooms['/foo'] assert sid in self.bm.rooms['/foo'][None] @@ -37,8 +37,8 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.sid_from_eio_sid('123', '/foo') == sid def test_pre_disconnect(self): - sid1 = self.bm.connect('123', '/foo') - sid2 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) + sid2 = _run(self.bm.connect('456', '/foo')) assert self.bm.is_connected(sid1, '/foo') assert self.bm.pre_disconnect(sid1, '/foo') == '123' assert self.bm.pending_disconnect == {'/foo': [sid1]} @@ -52,8 +52,8 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.pending_disconnect == {} def test_disconnect(self): - sid1 = self.bm.connect('123', '/foo') - sid2 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) _run(self.bm.enter_room(sid2, '/foo', 'baz')) _run(self.bm.disconnect(sid1, '/foo')) @@ -62,10 +62,10 @@ class TestAsyncManager(unittest.TestCase): assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'} def test_disconnect_default_namespace(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') - sid3 = self.bm.connect('456', '/') - sid4 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) + sid3 = _run(self.bm.connect('456', '/')) + sid4 = _run(self.bm.connect('456', '/foo')) assert self.bm.is_connected(sid1, '/') assert self.bm.is_connected(sid2, '/foo') assert not self.bm.is_connected(sid2, '/') @@ -81,10 +81,10 @@ class TestAsyncManager(unittest.TestCase): assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'} def test_disconnect_twice(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') - sid3 = self.bm.connect('456', '/') - sid4 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) + sid3 = _run(self.bm.connect('456', '/')) + sid4 = _run(self.bm.connect('456', '/foo')) _run(self.bm.disconnect(sid1, '/')) _run(self.bm.disconnect(sid2, '/foo')) _run(self.bm.disconnect(sid1, '/')) @@ -95,8 +95,8 @@ class TestAsyncManager(unittest.TestCase): assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'} def test_disconnect_all(self): - sid1 = self.bm.connect('123', '/foo') - sid2 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) _run(self.bm.enter_room(sid2, '/foo', 'baz')) _run(self.bm.disconnect(sid1, '/foo')) @@ -104,9 +104,9 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.rooms == {} def test_disconnect_with_callbacks(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') - sid3 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) + sid3 = _run(self.bm.connect('456', '/foo')) self.bm._generate_ack_id(sid1, 'f') self.bm._generate_ack_id(sid2, 'g') self.bm._generate_ack_id(sid3, 'h') @@ -117,8 +117,8 @@ class TestAsyncManager(unittest.TestCase): assert sid3 in self.bm.callbacks def test_trigger_sync_callback(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) cb = mock.MagicMock() id1 = self.bm._generate_ack_id(sid1, cb) id2 = self.bm._generate_ack_id(sid2, cb) @@ -129,8 +129,8 @@ class TestAsyncManager(unittest.TestCase): cb.assert_any_call('bar', 'baz') def test_trigger_async_callback(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) cb = AsyncMock() id1 = self.bm._generate_ack_id(sid1, cb) id2 = self.bm._generate_ack_id(sid2, cb) @@ -141,7 +141,7 @@ class TestAsyncManager(unittest.TestCase): cb.mock.assert_any_call('bar', 'baz') def test_invalid_callback(self): - sid = self.bm.connect('123', '/') + sid = _run(self.bm.connect('123', '/')) cb = mock.MagicMock() id = self.bm._generate_ack_id(sid, cb) @@ -152,17 +152,17 @@ class TestAsyncManager(unittest.TestCase): def test_get_namespaces(self): assert list(self.bm.get_namespaces()) == [] - self.bm.connect('123', '/') - self.bm.connect('123', '/foo') + _run(self.bm.connect('123', '/')) + _run(self.bm.connect('123', '/foo')) namespaces = list(self.bm.get_namespaces()) assert len(namespaces) == 2 assert '/' in namespaces assert '/foo' in namespaces def test_get_participants(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('456', '/') - sid3 = self.bm.connect('789', '/') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('456', '/')) + sid3 = _run(self.bm.connect('789', '/')) _run(self.bm.disconnect(sid3, '/')) assert sid3 not in self.bm.rooms['/'][None] participants = list(self.bm.get_participants('/', None)) @@ -172,7 +172,7 @@ class TestAsyncManager(unittest.TestCase): assert (sid3, '789') not in participants def test_leave_invalid_room(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run(self.bm.leave_room(sid, '/foo', 'baz')) _run(self.bm.leave_room(sid, '/bar', 'baz')) @@ -181,9 +181,9 @@ class TestAsyncManager(unittest.TestCase): assert [] == rooms def test_close_room(self): - sid = self.bm.connect('123', '/foo') - self.bm.connect('456', '/foo') - self.bm.connect('789', '/foo') + sid = _run(self.bm.connect('123', '/foo')) + _run(self.bm.connect('456', '/foo')) + _run(self.bm.connect('789', '/foo')) _run(self.bm.enter_room(sid, '/foo', 'bar')) _run(self.bm.enter_room(sid, '/foo', 'bar')) _run(self.bm.close_room('bar', '/foo')) @@ -195,7 +195,7 @@ class TestAsyncManager(unittest.TestCase): self.bm.close_room('bar', '/foo') def test_rooms(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid, '/foo', 'bar')) r = self.bm.get_rooms(sid, '/foo') assert len(r) == 2 @@ -203,8 +203,8 @@ class TestAsyncManager(unittest.TestCase): assert 'bar' in r def test_emit_to_sid(self): - sid = self.bm.connect('123', '/foo') - self.bm.connect('456', '/foo') + sid = _run(self.bm.connect('123', '/foo')) + _run(self.bm.connect('456', '/foo')) _run( self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', room=sid @@ -217,11 +217,11 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_room(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - self.bm.connect('789', '/foo') + _run(self.bm.connect('789', '/foo')) _run( self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', room='bar' @@ -238,12 +238,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_rooms(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) _run(self.bm.enter_room(sid2, '/foo', 'baz')) - sid3 = self.bm.connect('789', '/foo') + sid3 = _run(self.bm.connect('789', '/foo')) _run(self.bm.enter_room(sid3, '/foo', 'baz')) _run( self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', @@ -264,12 +264,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_all(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - self.bm.connect('789', '/foo') - self.bm.connect('abc', '/bar') + _run(self.bm.connect('789', '/foo')) + _run(self.bm.connect('abc', '/bar')) _run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')) assert self.bm.server._send_eio_packet.mock.call_count == 3 assert self.bm.server._send_eio_packet.mock.call_args_list[0][0][0] \ @@ -286,12 +286,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_all_skip_one(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - self.bm.connect('789', '/foo') - self.bm.connect('abc', '/bar') + _run(self.bm.connect('789', '/foo')) + _run(self.bm.connect('abc', '/bar')) _run( self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', skip_sid=sid2 @@ -308,12 +308,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_all_skip_two(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - sid3 = self.bm.connect('789', '/foo') - self.bm.connect('abc', '/bar') + sid3 = _run(self.bm.connect('789', '/foo')) + _run(self.bm.connect('abc', '/bar')) _run( self.bm.emit( 'my event', @@ -329,7 +329,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_with_callback(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) self.bm._generate_ack_id = mock.MagicMock() self.bm._generate_ack_id.return_value = 11 _run( diff --git a/tests/async/test_pubsub_manager.py b/tests/async/test_pubsub_manager.py index 3b4d0a9..da0b86d 100644 --- a/tests/async/test_pubsub_manager.py +++ b/tests/async/test_pubsub_manager.py @@ -145,7 +145,7 @@ class TestAsyncPubSubManager(unittest.TestCase): _run(self.pm.emit('foo', 'bar', callback='cb')) def test_emit_with_ignore_queue(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) _run( self.pm.emit( 'foo', 'bar', room=sid, namespace='/', ignore_queue=True @@ -159,7 +159,7 @@ class TestAsyncPubSubManager(unittest.TestCase): assert pkt.encode() == '42["foo","bar"]' def test_can_disconnect(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) assert _run(self.pm.can_disconnect(sid, '/')) is True _run(self.pm.can_disconnect(sid, '/foo')) self.pm._publish.mock.assert_called_once_with( @@ -175,7 +175,7 @@ class TestAsyncPubSubManager(unittest.TestCase): ) def test_disconnect_ignore_queue(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) self.pm.pre_disconnect(sid, '/') _run(self.pm.disconnect(sid, '/', ignore_queue=True)) self.pm._publish.mock.assert_not_called() diff --git a/tests/async/test_server.py b/tests/async/test_server.py index e324c39..2f84b5f 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -596,7 +596,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = AsyncMock() catchall_handler = AsyncMock() s.on('msg', handler) @@ -610,7 +610,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_namespace(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/foo') + sid = _run(s.manager.connect('123', '/foo')) handler = mock.MagicMock() catchall_handler = mock.MagicMock() s.on('msg', handler, namespace='/foo') @@ -624,7 +624,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_disconnected_namespace(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - s.manager.connect('123', '/foo') + _run(s.manager.connect('123', '/foo')) handler = mock.MagicMock() s.on('my message', handler, namespace='/bar') _run(s._handle_eio_message('123', '2/bar,["my message","a","b","c"]')) @@ -633,7 +633,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_binary(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock() s.on('my message', handler) _run( @@ -652,7 +652,7 @@ class TestAsyncServer(unittest.TestCase): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) s.manager.trigger_callback = AsyncMock() - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) _run( s._handle_eio_message( '123', @@ -667,7 +667,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value='foo') s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","foo"]')) @@ -679,7 +679,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_unknown_event_with_ack(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - s.manager.connect('123', '/') + _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value='foo') s.on('my message', handler) _run(s._handle_eio_message('123', '21000["another message","foo"]')) @@ -688,7 +688,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_none(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=None) s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","foo"]')) @@ -698,7 +698,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_tuple(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=(1, '2', True)) s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","a","b","c"]')) @@ -710,7 +710,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_list(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=[1, '2', True]) s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","a","b","c"]')) @@ -722,7 +722,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_binary(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=b'foo') s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","foo"]')) @@ -973,7 +973,7 @@ class TestAsyncServer(unittest.TestCase): def test_async_handlers(self, eio): s = async_server.AsyncServer(async_handlers=True) - s.manager.connect('123', '/') + _run(s.manager.connect('123', '/')) _run(s._handle_eio_message('123', '2["my message","a","b","c"]')) s.eio.start_background_task.assert_called_once_with( s._handle_event_internal,