From 308b0c8eeb71e1fead35d19088a3291a15ccd50a Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 4 Dec 2020 00:05:38 +0000 Subject: [PATCH] v5 protocol: handle per-namespace sids in base manager --- setup.py | 1 + socketio/asyncio_client.py | 4 +- socketio/asyncio_manager.py | 4 +- socketio/asyncio_server.py | 114 +++++++++---------- socketio/base_manager.py | 36 +++--- socketio/client.py | 4 +- socketio/server.py | 113 +++++++++--------- tests/common/test_base_manager.py | 183 +++++++++++++++--------------- 8 files changed, 229 insertions(+), 230 deletions(-) diff --git a/setup.py b/setup.py index f0d2f26..da83f00 100755 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ setup( platforms='any', install_requires=[ 'six>=1.9.0', + 'bidict>=0.21.0', 'python-engineio>=3.13.0,<4' ], extras_require={ diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index 615c7c0..029a1ea 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -353,7 +353,7 @@ class AsyncClient(client.Client): if namespace in self.namespaces: self.namespaces.remove(namespace) if namespace == '/': - self.namespaces = [] + self.namespaces = {} self.connected = False async def _trigger_event(self, event, namespace, *args): @@ -456,7 +456,7 @@ class AsyncClient(client.Client): if self.connected: for n in self.namespaces: await self._trigger_event('disconnect', namespace=n) - self.namespaces = [] + self.namespaces = {} self.connected = False self.callbacks = {} self._binary_packet = None diff --git a/socketio/asyncio_manager.py b/socketio/asyncio_manager.py index 55032ab..db9fecf 100644 --- a/socketio/asyncio_manager.py +++ b/socketio/asyncio_manager.py @@ -20,13 +20,13 @@ class AsyncManager(BaseManager): tasks = [] if not isinstance(skip_sid, list): skip_sid = [skip_sid] - for sid in self.get_participants(namespace, room): + for sid, eio_sid in self.get_participants(namespace, room): if sid not in skip_sid: if callback is not None: id = self._generate_ack_id(sid, namespace, callback) else: id = None - tasks.append(self.server._emit_internal(sid, event, data, + tasks.append(self.server._emit_internal(eio_sid, event, data, namespace, id)) if tasks == []: # pragma: no cover return diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 2cf5d4c..5eaeb3a 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -336,13 +336,11 @@ class AsyncServer(server.Server): delete_it = await self.manager.can_disconnect(sid, namespace) if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) - self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) + await self._send_packet(eio_sid, packet.Packet( + packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) - if namespace == '/': - await self.eio.disconnect(sid) async def handle_request(self, *args, **kwargs): """Handle an HTTP request from the client. @@ -396,26 +394,26 @@ class AsyncServer(server.Server): await self._send_packet(sid, packet.Packet( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) - async def _send_packet(self, sid, pkt): + async def _send_packet(self, eio_sid, pkt): """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): for ep in encoded_packet: - await self.eio.send(sid, ep) + await self.eio.send(eio_sid, ep) else: - await self.eio.send(sid, encoded_packet) + await self.eio.send(eio_sid, encoded_packet) - async def _handle_connect(self, sid, namespace): + async def _handle_connect(self, eio_sid, namespace): """Handle a client connection request.""" namespace = namespace or '/' - self.manager.connect(sid, namespace) + sid = self.manager.connect(eio_sid, namespace) if self.always_connect: - await self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + await self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = None try: success = await self._trigger_event('connect', namespace, sid, - self.environ[sid]) + self.environ[eio_sid]) except exceptions.ConnectionRefusedError as exc: fail_reason = exc.error_args success = False @@ -423,40 +421,34 @@ class AsyncServer(server.Server): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(sid, packet.Packet( + await self._send_packet(eio_sid, packet.Packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) elif namespace != '/': - await self._send_packet(sid, packet.Packet( + await self._send_packet(eio_sid, packet.Packet( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace) - if namespace == '/' and sid in self.environ: # pragma: no cover - del self.environ[sid] + if namespace == '/' and \ + eio_sid in self.environ: # pragma: no cover + del self.environ[eio_sid] return fail_reason or False elif not self.always_connect: - await self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + await self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) - async def _handle_disconnect(self, sid, namespace): + async def _handle_disconnect(self, eio_sid, namespace): """Handle a client disconnect.""" namespace = namespace or '/' - if namespace == '/': - namespace_list = list(self.manager.get_namespaces()) - else: - namespace_list = [namespace] - for n in namespace_list: - if n != '/' and self.manager.is_connected(sid, n): - self.manager.pre_disconnect(sid, namespace=n) - await self._trigger_event('disconnect', n, sid) - self.manager.disconnect(sid, n) - if namespace == '/' and self.manager.is_connected(sid, namespace): - self.manager.pre_disconnect(sid, namespace='/') - await self._trigger_event('disconnect', '/', sid) - self.manager.disconnect(sid, '/') - - async def _handle_event(self, sid, namespace, id, data): + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) + if self.manager.is_connected(sid, namespace): + self.manager.pre_disconnect(sid, namespace=namespace) + await self._trigger_event('disconnect', namespace, sid) + self.manager.disconnect(sid, namespace) + + async def _handle_event(self, eio_sid, namespace, id, data): """Handle an incoming client event.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) if not self.manager.is_connected(sid, namespace): @@ -465,11 +457,13 @@ class AsyncServer(server.Server): return if self.async_handlers: self.start_background_task(self._handle_event_internal, self, sid, - data, namespace, id) + eio_sid, data, namespace, id) else: - await self._handle_event_internal(self, sid, data, namespace, id) + await self._handle_event_internal(self, sid, eio_sid, data, + namespace, id) - async def _handle_event_internal(self, server, sid, data, namespace, id): + async def _handle_event_internal(self, server, sid, eio_sid, data, + namespace, id): r = await server._trigger_event(data[0], namespace, sid, *data[1:]) if id is not None: # send ACK packet with the response returned by the handler @@ -480,13 +474,13 @@ class AsyncServer(server.Server): data = list(r) else: data = [r] - await server._send_packet(sid, packet.Packet(packet.ACK, - namespace=namespace, - id=id, data=data)) + await server._send_packet(eio_sid, packet.Packet( + packet.ACK, namespace=namespace, id=id, data=data)) - async def _handle_ack(self, sid, namespace, id, data): + async def _handle_ack(self, eio_sid, namespace, id, data): """Handle ACK packets from the client.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace) await self.manager.trigger_callback(sid, namespace, id, data) @@ -509,48 +503,50 @@ class AsyncServer(server.Server): return await self.namespace_handlers[namespace].trigger_event( event, *args) - async def _handle_eio_connect(self, sid, environ): + async def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" if not self.manager_initialized: self.manager_initialized = True self.manager.initialize() - self.environ[sid] = environ + self.environ[eio_sid] = environ - async def _handle_eio_message(self, sid, data): + async def _handle_eio_message(self, eio_sid, data): """Dispatch Engine.IO messages.""" - if sid in self._binary_packet: - pkt = self._binary_packet[sid] + if eio_sid in self._binary_packet: + pkt = self._binary_packet[eio_sid] if pkt.add_attachment(data): - del self._binary_packet[sid] + del self._binary_packet[eio_sid] if pkt.packet_type == packet.BINARY_EVENT: - await self._handle_event(sid, pkt.namespace, pkt.id, + await self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - await self._handle_ack(sid, pkt.namespace, pkt.id, + await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - await self._handle_connect(sid, pkt.namespace) + await self._handle_connect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.DISCONNECT: - await self._handle_disconnect(sid, pkt.namespace) + await self._handle_disconnect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.EVENT: - await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + await self._handle_event(eio_sid, pkt.namespace, pkt.id, + pkt.data) elif pkt.packet_type == packet.ACK: - await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + await self._handle_ack(eio_sid, pkt.namespace, pkt.id, + pkt.data) elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: - self._binary_packet[sid] = pkt + self._binary_packet[eio_sid] = pkt elif pkt.packet_type == packet.CONNECT_ERROR: raise ValueError('Unexpected CONNECT_ERROR packet.') else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self, sid): + async def _handle_eio_disconnect(self, eio_sid): """Handle Engine.IO disconnect event.""" - await self._handle_disconnect(sid, '/') - if sid in self.environ: - del self.environ[sid] + await self._handle_disconnect(eio_sid, '/') + if eio_sid in self.environ: + del self.environ[eio_sid] def _engineio_server_class(self): return engineio.AsyncServer diff --git a/socketio/base_manager.py b/socketio/base_manager.py index 0926462..2053dc1 100644 --- a/socketio/base_manager.py +++ b/socketio/base_manager.py @@ -1,6 +1,7 @@ import itertools import logging +from bidict import bidict import six default_logger = logging.getLogger('socketio') @@ -18,7 +19,8 @@ class BaseManager(object): def __init__(self): self.logger = None self.server = None - self.rooms = {} + self.rooms = {} # self.rooms[namespace][room][sio_sid] = eio_sid + self.eio_to_sid = {} self.callbacks = {} self.pending_disconnect = {} @@ -37,13 +39,15 @@ class BaseManager(object): def get_participants(self, namespace, room): """Return an iterable with the active participants in a room.""" - for sid, active in six.iteritems(self.rooms[namespace][room].copy()): - yield sid + for sid, eio_sid in self.rooms[namespace][room].copy().items(): + yield sid, eio_sid - def connect(self, sid, namespace): + def connect(self, eio_sid, namespace): """Register a client connection to a namespace.""" - self.enter_room(sid, namespace, None) - self.enter_room(sid, namespace, sid) + sid = self.server.eio.generate_id() + self.enter_room(sid, namespace, None, eio_sid=eio_sid) + self.enter_room(sid, namespace, sid, eio_sid=eio_sid) + return sid def is_connected(self, sid, namespace): if namespace in self.pending_disconnect and \ @@ -55,6 +59,9 @@ class BaseManager(object): except KeyError: pass + def sid_from_eio_sid(self, eio_sid, namespace): + return self.rooms[namespace][None].inverse.get(eio_sid) + def can_disconnect(self, sid, namespace): return self.is_connected(sid, namespace) @@ -68,6 +75,7 @@ class BaseManager(object): if namespace not in self.pending_disconnect: self.pending_disconnect[namespace] = [] self.pending_disconnect[namespace].append(sid) + return self.rooms[namespace][None].get(sid) def disconnect(self, sid, namespace): """Register a client disconnect from a namespace.""" @@ -89,13 +97,15 @@ class BaseManager(object): if len(self.pending_disconnect[namespace]) == 0: del self.pending_disconnect[namespace] - def enter_room(self, sid, namespace, room): + def enter_room(self, sid, namespace, room, eio_sid=None): """Add a client to a room.""" if namespace not in self.rooms: self.rooms[namespace] = {} if room not in self.rooms[namespace]: - self.rooms[namespace][room] = {} - self.rooms[namespace][room][sid] = True + self.rooms[namespace][room] = bidict() + if eio_sid is None: + eio_sid = self.rooms[namespace][None][sid] + self.rooms[namespace][room][sid] = eio_sid def leave_room(self, sid, namespace, room): """Remove a client from a room.""" @@ -111,7 +121,7 @@ class BaseManager(object): def close_room(self, room, namespace): """Remove all participants from a room.""" try: - for sid in self.get_participants(namespace, room): + for sid, _ in self.get_participants(namespace, room): self.leave_room(sid, namespace, room) except KeyError: pass @@ -121,7 +131,7 @@ class BaseManager(object): r = [] try: for room_name, room in six.iteritems(self.rooms[namespace]): - if room_name is not None and sid in room and room[sid]: + if room_name is not None and sid in room: r.append(room_name) except KeyError: pass @@ -135,13 +145,13 @@ class BaseManager(object): return if not isinstance(skip_sid, list): skip_sid = [skip_sid] - for sid in self.get_participants(namespace, room): + for sid, eio_sid in self.get_participants(namespace, room): if sid not in skip_sid: if callback is not None: id = self._generate_ack_id(sid, namespace, callback) else: id = None - self.server._emit_internal(sid, event, data, namespace, id) + self.server._emit_internal(eio_sid, event, data, namespace, id) def trigger_callback(self, sid, namespace, id, data): """Invoke an application callback.""" diff --git a/socketio/client.py b/socketio/client.py index 5873a5b..9c5aa12 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -533,7 +533,7 @@ class Client(object): if namespace in self.namespaces: del self.namespaces[namespace] if namespace == '/': - self.namespaces = [] + self.namespaces = {} self.connected = False def _trigger_event(self, event, namespace, *args): @@ -625,7 +625,7 @@ class Client(object): if self.connected: for n in self.namespaces: self._trigger_event('disconnect', namespace=n) - self.namespaces = [] + self.namespaces = {} self.connected = False self.callbacks = {} self._binary_packet = None diff --git a/socketio/server.py b/socketio/server.py index d6bc7e3..b842f40 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -520,13 +520,11 @@ class Server(object): delete_it = self.manager.can_disconnect(sid, namespace) if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) - self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) + self._send_packet(eio_sid, packet.Packet( + packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) - if namespace == '/': - self.eio.disconnect(sid) def transport(self, sid): """Return the name of the transport used by the client. @@ -594,26 +592,26 @@ class Server(object): self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace, data=[event] + data, id=id)) - def _send_packet(self, sid, pkt): + def _send_packet(self, eio_sid, pkt): """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): for ep in encoded_packet: - self.eio.send(sid, ep) + self.eio.send(eio_sid, ep) else: - self.eio.send(sid, encoded_packet) + self.eio.send(eio_sid, encoded_packet) - def _handle_connect(self, sid, namespace): + def _handle_connect(self, eio_sid, namespace): """Handle a client connection request.""" namespace = namespace or '/' - self.manager.connect(sid, namespace) + sid = self.manager.connect(eio_sid, namespace) if self.always_connect: - self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = None try: success = self._trigger_event('connect', namespace, sid, - self.environ[sid]) + self.environ[eio_sid]) except exceptions.ConnectionRefusedError as exc: fail_reason = exc.error_args success = False @@ -621,40 +619,34 @@ class Server(object): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(sid, packet.Packet( + self._send_packet(eio_sid, packet.Packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) elif namespace != '/': - self._send_packet(sid, packet.Packet( + self._send_packet(eio_sid, packet.Packet( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace) - if namespace == '/' and sid in self.environ: # pragma: no cover - del self.environ[sid] + if namespace == '/' and \ + eio_sid in self.environ: # pragma: no cover + del self.environ[eio_sid] return fail_reason or False elif not self.always_connect: - self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) - def _handle_disconnect(self, sid, namespace): + def _handle_disconnect(self, eio_sid, namespace): """Handle a client disconnect.""" namespace = namespace or '/' - if namespace == '/': - namespace_list = list(self.manager.get_namespaces()) - else: - namespace_list = [namespace] - for n in namespace_list: - if n != '/' and self.manager.is_connected(sid, n): - self.manager.pre_disconnect(sid, namespace=n) - self._trigger_event('disconnect', n, sid) - self.manager.disconnect(sid, n) - if namespace == '/' and self.manager.is_connected(sid, namespace): - self.manager.pre_disconnect(sid, namespace='/') - self._trigger_event('disconnect', '/', sid) - self.manager.disconnect(sid, '/') - - def _handle_event(self, sid, namespace, id, data): + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) + if self.manager.is_connected(sid, namespace): + self.manager.pre_disconnect(sid, namespace=namespace) + self._trigger_event('disconnect', namespace, sid) + self.manager.disconnect(sid, namespace) + + def _handle_event(self, eio_sid, namespace, id, data): """Handle an incoming client event.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) if not self.manager.is_connected(sid, namespace): @@ -663,11 +655,13 @@ class Server(object): return if self.async_handlers: self.start_background_task(self._handle_event_internal, self, sid, - data, namespace, id) + eio_sid, data, namespace, id) else: - self._handle_event_internal(self, sid, data, namespace, id) + self._handle_event_internal(self, sid, eio_sid, data, namespace, + id) - def _handle_event_internal(self, server, sid, data, namespace, id): + def _handle_event_internal(self, server, sid, eio_sid, data, namespace, + id): r = server._trigger_event(data[0], namespace, sid, *data[1:]) if id is not None: # send ACK packet with the response returned by the handler @@ -678,13 +672,13 @@ class Server(object): data = list(r) else: data = [r] - server._send_packet(sid, packet.Packet(packet.ACK, - namespace=namespace, - id=id, data=data)) + server._send_packet(eio_sid, packet.Packet( + packet.ACK, namespace=namespace, id=id, data=data)) - def _handle_ack(self, sid, namespace, id, data): + def _handle_ack(self, eio_sid, namespace, id, data): """Handle ACK packets from the client.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace) self.manager.trigger_callback(sid, namespace, id, data) @@ -699,46 +693,47 @@ class Server(object): return self.namespace_handlers[namespace].trigger_event( event, *args) - def _handle_eio_connect(self, sid, environ): + def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" if not self.manager_initialized: self.manager_initialized = True self.manager.initialize() - self.environ[sid] = environ + self.environ[eio_sid] = environ - def _handle_eio_message(self, sid, data): + def _handle_eio_message(self, eio_sid, data): """Dispatch Engine.IO messages.""" - if sid in self._binary_packet: - pkt = self._binary_packet[sid] + if eio_sid in self._binary_packet: + pkt = self._binary_packet[eio_sid] if pkt.add_attachment(data): - del self._binary_packet[sid] + del self._binary_packet[eio_sid] if pkt.packet_type == packet.BINARY_EVENT: - self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_event(eio_sid, pkt.namespace, pkt.id, + pkt.data) else: - self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - self._handle_connect(sid, pkt.namespace) + self._handle_connect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.DISCONNECT: - self._handle_disconnect(sid, pkt.namespace) + self._handle_disconnect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.EVENT: - self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.ACK: - self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: - self._binary_packet[sid] = pkt + self._binary_packet[eio_sid] = pkt elif pkt.packet_type == packet.CONNECT_ERROR: raise ValueError('Unexpected CONNECT_ERROR packet.') else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self, sid): + def _handle_eio_disconnect(self, eio_sid): """Handle Engine.IO disconnect event.""" - self._handle_disconnect(sid, '/') - if sid in self.environ: - del self.environ[sid] + self._handle_disconnect(eio_sid, '/') + if eio_sid in self.environ: + del self.environ[eio_sid] def _engineio_server_class(self): return engineio.Server diff --git a/tests/common/test_base_manager.py b/tests/common/test_base_manager.py index 4425ad3..479b80f 100644 --- a/tests/common/test_base_manager.py +++ b/tests/common/test_base_manager.py @@ -12,21 +12,28 @@ from socketio import base_manager class TestBaseManager(unittest.TestCase): def setUp(self): + id = 0 + + def generate_id(): + nonlocal id + id += 1 + return str(id) + mock_server = mock.MagicMock() + mock_server.eio.generate_id = generate_id self.bm = base_manager.BaseManager() self.bm.set_server(mock_server) self.bm.initialize() def test_connect(self): - self.bm.connect('123', '/foo') + sid = self.bm.connect('123', '/foo') assert None in self.bm.rooms['/foo'] - assert '123' in self.bm.rooms['/foo'] - assert '123' in self.bm.rooms['/foo'][None] - assert '123' in self.bm.rooms['/foo']['123'] - assert self.bm.rooms['/foo'] == { - None: {'123': True}, - '123': {'123': True}, - } + assert sid in self.bm.rooms['/foo'] + assert sid in self.bm.rooms['/foo'][None] + assert sid in self.bm.rooms['/foo'][sid] + assert dict(self.bm.rooms['/foo'][None]) == {sid: '123'} + assert dict(self.bm.rooms['/foo'][sid]) == {sid: '123'} + assert self.bm.sid_from_eio_sid('123', '/foo') == sid def test_pre_disconnect(self): self.bm.connect('123', '/foo') @@ -43,63 +50,53 @@ class TestBaseManager(unittest.TestCase): assert self.bm.pending_disconnect == {} def test_disconnect(self): - self.bm.connect('123', '/foo') - self.bm.connect('456', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.enter_room('456', '/foo', 'baz') - self.bm.disconnect('123', '/foo') - assert self.bm.rooms['/foo'] == { - None: {'456': True}, - '456': {'456': True}, - 'baz': {'456': True}, - } + sid1 = self.bm.connect('123', '/foo') + sid2 = self.bm.connect('456', '/foo') + self.bm.enter_room(sid1, '/foo', 'bar') + self.bm.enter_room(sid2, '/foo', 'baz') + self.bm.disconnect(sid1, '/foo') + assert dict(self.bm.rooms['/foo'][None]) == {sid2: '456'} + assert dict(self.bm.rooms['/foo'][sid2]) == {sid2: '456'} + assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'} def test_disconnect_default_namespace(self): - self.bm.connect('123', '/') - self.bm.connect('123', '/foo') - self.bm.connect('456', '/') - self.bm.connect('456', '/foo') - assert self.bm.is_connected('123', '/') - assert self.bm.is_connected('123', '/foo') - self.bm.disconnect('123', '/') - assert not self.bm.is_connected('123', '/') - assert self.bm.is_connected('123', '/foo') - self.bm.disconnect('123', '/foo') - assert not self.bm.is_connected('123', '/foo') - assert self.bm.rooms['/'] == { - None: {'456': True}, - '456': {'456': True}, - } - assert self.bm.rooms['/foo'] == { - None: {'456': True}, - '456': {'456': True}, - } + sid1 = self.bm.connect('123', '/') + sid2 = self.bm.connect('123', '/foo') + sid3 = self.bm.connect('456', '/') + sid4 = self.bm.connect('456', '/foo') + assert self.bm.is_connected(sid1, '/') + assert self.bm.is_connected(sid2, '/foo') + self.bm.disconnect(sid1, '/') + assert not self.bm.is_connected(sid1, '/') + assert self.bm.is_connected(sid2, '/foo') + self.bm.disconnect(sid2, '/foo') + assert not self.bm.is_connected(sid2, '/foo') + assert dict(self.bm.rooms['/'][None]) == {sid3: '456'} + assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'} + assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'} + assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'} def test_disconnect_twice(self): - self.bm.connect('123', '/') - self.bm.connect('123', '/foo') - self.bm.connect('456', '/') - self.bm.connect('456', '/foo') - self.bm.disconnect('123', '/') - self.bm.disconnect('123', '/foo') - self.bm.disconnect('123', '/') - self.bm.disconnect('123', '/foo') - assert self.bm.rooms['/'] == { - None: {'456': True}, - '456': {'456': True}, - } - assert self.bm.rooms['/foo'] == { - None: {'456': True}, - '456': {'456': True}, - } + sid1 = self.bm.connect('123', '/') + sid2 = self.bm.connect('123', '/foo') + sid3 = self.bm.connect('456', '/') + sid4 = self.bm.connect('456', '/foo') + self.bm.disconnect(sid1, '/') + self.bm.disconnect(sid2, '/foo') + self.bm.disconnect(sid1, '/') + self.bm.disconnect(sid2, '/foo') + assert dict(self.bm.rooms['/'][None]) == {sid3: '456'} + assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'} + assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'} + assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'} def test_disconnect_all(self): - self.bm.connect('123', '/foo') - self.bm.connect('456', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.enter_room('456', '/foo', 'baz') - self.bm.disconnect('123', '/foo') - self.bm.disconnect('456', '/foo') + sid1 = self.bm.connect('123', '/foo') + sid2 = self.bm.connect('456', '/foo') + self.bm.enter_room(sid1, '/foo', 'bar') + self.bm.enter_room(sid2, '/foo', 'baz') + self.bm.disconnect(sid1, '/foo') + self.bm.disconnect(sid2, '/foo') assert self.bm.rooms == {} def test_disconnect_with_callbacks(self): @@ -152,12 +149,12 @@ class TestBaseManager(unittest.TestCase): def test_get_participants(self): self.bm.connect('123', '/') self.bm.connect('456', '/') - self.bm.connect('789', '/') - self.bm.disconnect('789', '/') - assert '789' not in self.bm.rooms['/'][None] + sid = self.bm.connect('789', '/') + self.bm.disconnect(sid, '/') + assert sid not in self.bm.rooms['/'][None] participants = list(self.bm.get_participants('/', None)) assert len(participants) == 2 - assert '789' not in participants + assert sid not in participants def test_leave_invalid_room(self): self.bm.connect('123', '/foo') @@ -169,11 +166,11 @@ class TestBaseManager(unittest.TestCase): assert [] == rooms def test_close_room(self): - self.bm.connect('123', '/foo') + sid1 = self.bm.connect('123', '/foo') self.bm.connect('456', '/foo') self.bm.connect('789', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.enter_room('123', '/foo', 'bar') + self.bm.enter_room(sid1, '/foo', 'bar') + self.bm.enter_room(sid1, '/foo', 'bar') self.bm.close_room('bar', '/foo') assert 'bar' not in self.bm.rooms['/foo'] @@ -181,26 +178,26 @@ class TestBaseManager(unittest.TestCase): self.bm.close_room('bar', '/foo') def test_rooms(self): - self.bm.connect('123', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - r = self.bm.get_rooms('123', '/foo') + sid = self.bm.connect('123', '/foo') + self.bm.enter_room(sid, '/foo', 'bar') + r = self.bm.get_rooms(sid, '/foo') assert len(r) == 2 - assert '123' in r + assert sid in r assert 'bar' in r def test_emit_to_sid(self): - self.bm.connect('123', '/foo') + sid = self.bm.connect('123', '/foo') self.bm.connect('456', '/foo') - self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='123') + self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room=sid) self.bm.server._emit_internal.assert_called_once_with( '123', 'my event', {'foo': 'bar'}, '/foo', None ) def test_emit_to_room(self): - self.bm.connect('123', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.connect('456', '/foo') - self.bm.enter_room('456', '/foo', 'bar') + sid1 = self.bm.connect('123', '/foo') + self.bm.enter_room(sid1, '/foo', 'bar') + sid2 = self.bm.connect('456', '/foo') + self.bm.enter_room(sid2, '/foo', 'bar') self.bm.connect('789', '/foo') self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='bar') assert self.bm.server._emit_internal.call_count == 2 @@ -212,10 +209,10 @@ class TestBaseManager(unittest.TestCase): ) def test_emit_to_all(self): - self.bm.connect('123', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.connect('456', '/foo') - self.bm.enter_room('456', '/foo', 'bar') + sid1 = self.bm.connect('123', '/foo') + self.bm.enter_room(sid1, '/foo', 'bar') + sid2 = self.bm.connect('456', '/foo') + self.bm.enter_room(sid2, '/foo', 'bar') self.bm.connect('789', '/foo') self.bm.connect('abc', '/bar') self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo') @@ -231,14 +228,14 @@ class TestBaseManager(unittest.TestCase): ) def test_emit_to_all_skip_one(self): - self.bm.connect('123', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.connect('456', '/foo') - self.bm.enter_room('456', '/foo', 'bar') + sid1 = self.bm.connect('123', '/foo') + self.bm.enter_room(sid1, '/foo', 'bar') + sid2 = self.bm.connect('456', '/foo') + self.bm.enter_room(sid2, '/foo', 'bar') self.bm.connect('789', '/foo') self.bm.connect('abc', '/bar') self.bm.emit( - 'my event', {'foo': 'bar'}, namespace='/foo', skip_sid='456' + 'my event', {'foo': 'bar'}, namespace='/foo', skip_sid=sid2 ) assert self.bm.server._emit_internal.call_count == 2 self.bm.server._emit_internal.assert_any_call( @@ -249,17 +246,17 @@ class TestBaseManager(unittest.TestCase): ) def test_emit_to_all_skip_two(self): - self.bm.connect('123', '/foo') - self.bm.enter_room('123', '/foo', 'bar') - self.bm.connect('456', '/foo') - self.bm.enter_room('456', '/foo', 'bar') - self.bm.connect('789', '/foo') + sid1 = self.bm.connect('123', '/foo') + self.bm.enter_room(sid1, '/foo', 'bar') + sid2 = self.bm.connect('456', '/foo') + self.bm.enter_room(sid2, '/foo', 'bar') + sid3 = self.bm.connect('789', '/foo') self.bm.connect('abc', '/bar') self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', - skip_sid=['123', '789'], + skip_sid=[sid1, sid3], ) assert self.bm.server._emit_internal.call_count == 1 self.bm.server._emit_internal.assert_any_call( @@ -267,13 +264,13 @@ class TestBaseManager(unittest.TestCase): ) def test_emit_with_callback(self): - self.bm.connect('123', '/foo') + sid = self.bm.connect('123', '/foo') self.bm._generate_ack_id = mock.MagicMock() self.bm._generate_ack_id.return_value = 11 self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', callback='cb' ) - self.bm._generate_ack_id.assert_called_once_with('123', '/foo', 'cb') + self.bm._generate_ack_id.assert_called_once_with(sid, '/foo', 'cb') self.bm.server._emit_internal.assert_called_once_with( '123', 'my event', {'foo': 'bar'}, '/foo', 11 )