Browse Source

v5 protocol: handle per-namespace sids in base manager

pull/599/head
Miguel Grinberg 4 years ago
parent
commit
308b0c8eeb
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 1
      setup.py
  2. 4
      socketio/asyncio_client.py
  3. 4
      socketio/asyncio_manager.py
  4. 114
      socketio/asyncio_server.py
  5. 36
      socketio/base_manager.py
  6. 4
      socketio/client.py
  7. 113
      socketio/server.py
  8. 183
      tests/common/test_base_manager.py

1
setup.py

@ -30,6 +30,7 @@ setup(
platforms='any',
install_requires=[
'six>=1.9.0',
'bidict>=0.21.0',
'python-engineio>=3.13.0,<4'
],
extras_require={

4
socketio/asyncio_client.py

@ -353,7 +353,7 @@ class AsyncClient(client.Client):
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.namespaces = {}
self.connected = False
async def _trigger_event(self, event, namespace, *args):
@ -456,7 +456,7 @@ class AsyncClient(client.Client):
if self.connected:
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
self.namespaces = []
self.namespaces = {}
self.connected = False
self.callbacks = {}
self._binary_packet = None

4
socketio/asyncio_manager.py

@ -20,13 +20,13 @@ class AsyncManager(BaseManager):
tasks = []
if not isinstance(skip_sid, list):
skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room):
for sid, eio_sid in self.get_participants(namespace, room):
if sid not in skip_sid:
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
tasks.append(self.server._emit_internal(sid, event, data,
tasks.append(self.server._emit_internal(eio_sid, event, data,
namespace, id))
if tasks == []: # pragma: no cover
return

114
socketio/asyncio_server.py

@ -336,13 +336,11 @@ class AsyncServer(server.Server):
delete_it = await self.manager.can_disconnect(sid, namespace)
if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(eio_sid, packet.Packet(
packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
await self.eio.disconnect(sid)
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
@ -396,26 +394,26 @@ class AsyncServer(server.Server):
await self._send_packet(sid, packet.Packet(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))
async def _send_packet(self, sid, pkt):
async def _send_packet(self, eio_sid, pkt):
"""Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
for ep in encoded_packet:
await self.eio.send(sid, ep)
await self.eio.send(eio_sid, ep)
else:
await self.eio.send(sid, encoded_packet)
await self.eio.send(eio_sid, encoded_packet)
async def _handle_connect(self, sid, namespace):
async def _handle_connect(self, eio_sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
sid = self.manager.connect(eio_sid, namespace)
if self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
await self._send_packet(eio_sid, packet.Packet(
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = None
try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[sid])
self.environ[eio_sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
@ -423,40 +421,34 @@ class AsyncServer(server.Server):
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet(
await self._send_packet(eio_sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
elif namespace != '/':
await self._send_packet(sid, packet.Packet(
await self._send_packet(eio_sid, packet.Packet(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
if namespace == '/' and sid in self.environ: # pragma: no cover
del self.environ[sid]
if namespace == '/' and \
eio_sid in self.environ: # pragma: no cover
del self.environ[eio_sid]
return fail_reason or False
elif not self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
await self._send_packet(eio_sid, packet.Packet(
packet.CONNECT, {'sid': sid}, namespace=namespace))
async def _handle_disconnect(self, sid, namespace):
async def _handle_disconnect(self, eio_sid, namespace):
"""Handle a client disconnect."""
namespace = namespace or '/'
if namespace == '/':
namespace_list = list(self.manager.get_namespaces())
else:
namespace_list = [namespace]
for n in namespace_list:
if n != '/' and self.manager.is_connected(sid, n):
self.manager.pre_disconnect(sid, namespace=n)
await self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
self.manager.pre_disconnect(sid, namespace='/')
await self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
async def _handle_event(self, sid, namespace, id, data):
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
if self.manager.is_connected(sid, namespace):
self.manager.pre_disconnect(sid, namespace=namespace)
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace)
async def _handle_event(self, eio_sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
if not self.manager.is_connected(sid, namespace):
@ -465,11 +457,13 @@ class AsyncServer(server.Server):
return
if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id)
eio_sid, data, namespace, id)
else:
await self._handle_event_internal(self, sid, data, namespace, id)
await self._handle_event_internal(self, sid, eio_sid, data,
namespace, id)
async def _handle_event_internal(self, server, sid, data, namespace, id):
async def _handle_event_internal(self, server, sid, eio_sid, data,
namespace, id):
r = await server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
@ -480,13 +474,13 @@ class AsyncServer(server.Server):
data = list(r)
else:
data = [r]
await server._send_packet(sid, packet.Packet(packet.ACK,
namespace=namespace,
id=id, data=data))
await server._send_packet(eio_sid, packet.Packet(
packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, sid, namespace, id, data):
async def _handle_ack(self, eio_sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received ack from %s [%s]', sid, namespace)
await self.manager.trigger_callback(sid, namespace, id, data)
@ -509,48 +503,50 @@ class AsyncServer(server.Server):
return await self.namespace_handlers[namespace].trigger_event(
event, *args)
async def _handle_eio_connect(self, sid, environ):
async def _handle_eio_connect(self, eio_sid, environ):
"""Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ
self.environ[eio_sid] = environ
async def _handle_eio_message(self, sid, data):
async def _handle_eio_message(self, eio_sid, data):
"""Dispatch Engine.IO messages."""
if sid in self._binary_packet:
pkt = self._binary_packet[sid]
if eio_sid in self._binary_packet:
pkt = self._binary_packet[eio_sid]
if pkt.add_attachment(data):
del self._binary_packet[sid]
del self._binary_packet[eio_sid]
if pkt.packet_type == packet.BINARY_EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id,
await self._handle_event(eio_sid, pkt.namespace, pkt.id,
pkt.data)
else:
await self._handle_ack(sid, pkt.namespace, pkt.id,
await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(sid, pkt.namespace)
await self._handle_connect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(sid, pkt.namespace)
await self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
await self._handle_event(eio_sid, pkt.namespace, pkt.id,
pkt.data)
elif pkt.packet_type == packet.ACK:
await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt
self._binary_packet[eio_sid] = pkt
elif pkt.packet_type == packet.CONNECT_ERROR:
raise ValueError('Unexpected CONNECT_ERROR packet.')
else:
raise ValueError('Unknown packet type.')
async def _handle_eio_disconnect(self, sid):
async def _handle_eio_disconnect(self, eio_sid):
"""Handle Engine.IO disconnect event."""
await self._handle_disconnect(sid, '/')
if sid in self.environ:
del self.environ[sid]
await self._handle_disconnect(eio_sid, '/')
if eio_sid in self.environ:
del self.environ[eio_sid]
def _engineio_server_class(self):
return engineio.AsyncServer

36
socketio/base_manager.py

@ -1,6 +1,7 @@
import itertools
import logging
from bidict import bidict
import six
default_logger = logging.getLogger('socketio')
@ -18,7 +19,8 @@ class BaseManager(object):
def __init__(self):
self.logger = None
self.server = None
self.rooms = {}
self.rooms = {} # self.rooms[namespace][room][sio_sid] = eio_sid
self.eio_to_sid = {}
self.callbacks = {}
self.pending_disconnect = {}
@ -37,13 +39,15 @@ class BaseManager(object):
def get_participants(self, namespace, room):
"""Return an iterable with the active participants in a room."""
for sid, active in six.iteritems(self.rooms[namespace][room].copy()):
yield sid
for sid, eio_sid in self.rooms[namespace][room].copy().items():
yield sid, eio_sid
def connect(self, sid, namespace):
def connect(self, eio_sid, namespace):
"""Register a client connection to a namespace."""
self.enter_room(sid, namespace, None)
self.enter_room(sid, namespace, sid)
sid = self.server.eio.generate_id()
self.enter_room(sid, namespace, None, eio_sid=eio_sid)
self.enter_room(sid, namespace, sid, eio_sid=eio_sid)
return sid
def is_connected(self, sid, namespace):
if namespace in self.pending_disconnect and \
@ -55,6 +59,9 @@ class BaseManager(object):
except KeyError:
pass
def sid_from_eio_sid(self, eio_sid, namespace):
return self.rooms[namespace][None].inverse.get(eio_sid)
def can_disconnect(self, sid, namespace):
return self.is_connected(sid, namespace)
@ -68,6 +75,7 @@ class BaseManager(object):
if namespace not in self.pending_disconnect:
self.pending_disconnect[namespace] = []
self.pending_disconnect[namespace].append(sid)
return self.rooms[namespace][None].get(sid)
def disconnect(self, sid, namespace):
"""Register a client disconnect from a namespace."""
@ -89,13 +97,15 @@ class BaseManager(object):
if len(self.pending_disconnect[namespace]) == 0:
del self.pending_disconnect[namespace]
def enter_room(self, sid, namespace, room):
def enter_room(self, sid, namespace, room, eio_sid=None):
"""Add a client to a room."""
if namespace not in self.rooms:
self.rooms[namespace] = {}
if room not in self.rooms[namespace]:
self.rooms[namespace][room] = {}
self.rooms[namespace][room][sid] = True
self.rooms[namespace][room] = bidict()
if eio_sid is None:
eio_sid = self.rooms[namespace][None][sid]
self.rooms[namespace][room][sid] = eio_sid
def leave_room(self, sid, namespace, room):
"""Remove a client from a room."""
@ -111,7 +121,7 @@ class BaseManager(object):
def close_room(self, room, namespace):
"""Remove all participants from a room."""
try:
for sid in self.get_participants(namespace, room):
for sid, _ in self.get_participants(namespace, room):
self.leave_room(sid, namespace, room)
except KeyError:
pass
@ -121,7 +131,7 @@ class BaseManager(object):
r = []
try:
for room_name, room in six.iteritems(self.rooms[namespace]):
if room_name is not None and sid in room and room[sid]:
if room_name is not None and sid in room:
r.append(room_name)
except KeyError:
pass
@ -135,13 +145,13 @@ class BaseManager(object):
return
if not isinstance(skip_sid, list):
skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room):
for sid, eio_sid in self.get_participants(namespace, room):
if sid not in skip_sid:
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
self.server._emit_internal(sid, event, data, namespace, id)
self.server._emit_internal(eio_sid, event, data, namespace, id)
def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""

4
socketio/client.py

@ -533,7 +533,7 @@ class Client(object):
if namespace in self.namespaces:
del self.namespaces[namespace]
if namespace == '/':
self.namespaces = []
self.namespaces = {}
self.connected = False
def _trigger_event(self, event, namespace, *args):
@ -625,7 +625,7 @@ class Client(object):
if self.connected:
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self.namespaces = []
self.namespaces = {}
self.connected = False
self.callbacks = {}
self._binary_packet = None

113
socketio/server.py

@ -520,13 +520,11 @@ class Server(object):
delete_it = self.manager.can_disconnect(sid, namespace)
if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(eio_sid, packet.Packet(
packet.DISCONNECT, namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
self.eio.disconnect(sid)
def transport(self, sid):
"""Return the name of the transport used by the client.
@ -594,26 +592,26 @@ class Server(object):
self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id))
def _send_packet(self, sid, pkt):
def _send_packet(self, eio_sid, pkt):
"""Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
for ep in encoded_packet:
self.eio.send(sid, ep)
self.eio.send(eio_sid, ep)
else:
self.eio.send(sid, encoded_packet)
self.eio.send(eio_sid, encoded_packet)
def _handle_connect(self, sid, namespace):
def _handle_connect(self, eio_sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
sid = self.manager.connect(eio_sid, namespace)
if self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
self._send_packet(eio_sid, packet.Packet(
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = None
try:
success = self._trigger_event('connect', namespace, sid,
self.environ[sid])
self.environ[eio_sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
@ -621,40 +619,34 @@ class Server(object):
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
self._send_packet(sid, packet.Packet(
self._send_packet(eio_sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
elif namespace != '/':
self._send_packet(sid, packet.Packet(
self._send_packet(eio_sid, packet.Packet(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
if namespace == '/' and sid in self.environ: # pragma: no cover
del self.environ[sid]
if namespace == '/' and \
eio_sid in self.environ: # pragma: no cover
del self.environ[eio_sid]
return fail_reason or False
elif not self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
self._send_packet(eio_sid, packet.Packet(
packet.CONNECT, {'sid': sid}, namespace=namespace))
def _handle_disconnect(self, sid, namespace):
def _handle_disconnect(self, eio_sid, namespace):
"""Handle a client disconnect."""
namespace = namespace or '/'
if namespace == '/':
namespace_list = list(self.manager.get_namespaces())
else:
namespace_list = [namespace]
for n in namespace_list:
if n != '/' and self.manager.is_connected(sid, n):
self.manager.pre_disconnect(sid, namespace=n)
self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
self.manager.pre_disconnect(sid, namespace='/')
self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
def _handle_event(self, sid, namespace, id, data):
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
if self.manager.is_connected(sid, namespace):
self.manager.pre_disconnect(sid, namespace=namespace)
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace)
def _handle_event(self, eio_sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
if not self.manager.is_connected(sid, namespace):
@ -663,11 +655,13 @@ class Server(object):
return
if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id)
eio_sid, data, namespace, id)
else:
self._handle_event_internal(self, sid, data, namespace, id)
self._handle_event_internal(self, sid, eio_sid, data, namespace,
id)
def _handle_event_internal(self, server, sid, data, namespace, id):
def _handle_event_internal(self, server, sid, eio_sid, data, namespace,
id):
r = server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
@ -678,13 +672,13 @@ class Server(object):
data = list(r)
else:
data = [r]
server._send_packet(sid, packet.Packet(packet.ACK,
namespace=namespace,
id=id, data=data))
server._send_packet(eio_sid, packet.Packet(
packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, sid, namespace, id, data):
def _handle_ack(self, eio_sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received ack from %s [%s]', sid, namespace)
self.manager.trigger_callback(sid, namespace, id, data)
@ -699,46 +693,47 @@ class Server(object):
return self.namespace_handlers[namespace].trigger_event(
event, *args)
def _handle_eio_connect(self, sid, environ):
def _handle_eio_connect(self, eio_sid, environ):
"""Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ
self.environ[eio_sid] = environ
def _handle_eio_message(self, sid, data):
def _handle_eio_message(self, eio_sid, data):
"""Dispatch Engine.IO messages."""
if sid in self._binary_packet:
pkt = self._binary_packet[sid]
if eio_sid in self._binary_packet:
pkt = self._binary_packet[eio_sid]
if pkt.add_attachment(data):
del self._binary_packet[sid]
del self._binary_packet[eio_sid]
if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
self._handle_event(eio_sid, pkt.namespace, pkt.id,
pkt.data)
else:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(sid, pkt.namespace)
self._handle_connect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(sid, pkt.namespace)
self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt
self._binary_packet[eio_sid] = pkt
elif pkt.packet_type == packet.CONNECT_ERROR:
raise ValueError('Unexpected CONNECT_ERROR packet.')
else:
raise ValueError('Unknown packet type.')
def _handle_eio_disconnect(self, sid):
def _handle_eio_disconnect(self, eio_sid):
"""Handle Engine.IO disconnect event."""
self._handle_disconnect(sid, '/')
if sid in self.environ:
del self.environ[sid]
self._handle_disconnect(eio_sid, '/')
if eio_sid in self.environ:
del self.environ[eio_sid]
def _engineio_server_class(self):
return engineio.Server

183
tests/common/test_base_manager.py

@ -12,21 +12,28 @@ from socketio import base_manager
class TestBaseManager(unittest.TestCase):
def setUp(self):
id = 0
def generate_id():
nonlocal id
id += 1
return str(id)
mock_server = mock.MagicMock()
mock_server.eio.generate_id = generate_id
self.bm = base_manager.BaseManager()
self.bm.set_server(mock_server)
self.bm.initialize()
def test_connect(self):
self.bm.connect('123', '/foo')
sid = self.bm.connect('123', '/foo')
assert None in self.bm.rooms['/foo']
assert '123' in self.bm.rooms['/foo']
assert '123' in self.bm.rooms['/foo'][None]
assert '123' in self.bm.rooms['/foo']['123']
assert self.bm.rooms['/foo'] == {
None: {'123': True},
'123': {'123': True},
}
assert sid in self.bm.rooms['/foo']
assert sid in self.bm.rooms['/foo'][None]
assert sid in self.bm.rooms['/foo'][sid]
assert dict(self.bm.rooms['/foo'][None]) == {sid: '123'}
assert dict(self.bm.rooms['/foo'][sid]) == {sid: '123'}
assert self.bm.sid_from_eio_sid('123', '/foo') == sid
def test_pre_disconnect(self):
self.bm.connect('123', '/foo')
@ -43,63 +50,53 @@ class TestBaseManager(unittest.TestCase):
assert self.bm.pending_disconnect == {}
def test_disconnect(self):
self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.enter_room('456', '/foo', 'baz')
self.bm.disconnect('123', '/foo')
assert self.bm.rooms['/foo'] == {
None: {'456': True},
'456': {'456': True},
'baz': {'456': True},
}
sid1 = self.bm.connect('123', '/foo')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
self.bm.disconnect(sid1, '/foo')
assert dict(self.bm.rooms['/foo'][None]) == {sid2: '456'}
assert dict(self.bm.rooms['/foo'][sid2]) == {sid2: '456'}
assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'}
def test_disconnect_default_namespace(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
self.bm.connect('456', '/')
self.bm.connect('456', '/foo')
assert self.bm.is_connected('123', '/')
assert self.bm.is_connected('123', '/foo')
self.bm.disconnect('123', '/')
assert not self.bm.is_connected('123', '/')
assert self.bm.is_connected('123', '/foo')
self.bm.disconnect('123', '/foo')
assert not self.bm.is_connected('123', '/foo')
assert self.bm.rooms['/'] == {
None: {'456': True},
'456': {'456': True},
}
assert self.bm.rooms['/foo'] == {
None: {'456': True},
'456': {'456': True},
}
sid1 = self.bm.connect('123', '/')
sid2 = self.bm.connect('123', '/foo')
sid3 = self.bm.connect('456', '/')
sid4 = self.bm.connect('456', '/foo')
assert self.bm.is_connected(sid1, '/')
assert self.bm.is_connected(sid2, '/foo')
self.bm.disconnect(sid1, '/')
assert not self.bm.is_connected(sid1, '/')
assert self.bm.is_connected(sid2, '/foo')
self.bm.disconnect(sid2, '/foo')
assert not self.bm.is_connected(sid2, '/foo')
assert dict(self.bm.rooms['/'][None]) == {sid3: '456'}
assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'}
assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'}
assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'}
def test_disconnect_twice(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
self.bm.connect('456', '/')
self.bm.connect('456', '/foo')
self.bm.disconnect('123', '/')
self.bm.disconnect('123', '/foo')
self.bm.disconnect('123', '/')
self.bm.disconnect('123', '/foo')
assert self.bm.rooms['/'] == {
None: {'456': True},
'456': {'456': True},
}
assert self.bm.rooms['/foo'] == {
None: {'456': True},
'456': {'456': True},
}
sid1 = self.bm.connect('123', '/')
sid2 = self.bm.connect('123', '/foo')
sid3 = self.bm.connect('456', '/')
sid4 = self.bm.connect('456', '/foo')
self.bm.disconnect(sid1, '/')
self.bm.disconnect(sid2, '/foo')
self.bm.disconnect(sid1, '/')
self.bm.disconnect(sid2, '/foo')
assert dict(self.bm.rooms['/'][None]) == {sid3: '456'}
assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'}
assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'}
assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'}
def test_disconnect_all(self):
self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.enter_room('456', '/foo', 'baz')
self.bm.disconnect('123', '/foo')
self.bm.disconnect('456', '/foo')
sid1 = self.bm.connect('123', '/foo')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
self.bm.disconnect(sid1, '/foo')
self.bm.disconnect(sid2, '/foo')
assert self.bm.rooms == {}
def test_disconnect_with_callbacks(self):
@ -152,12 +149,12 @@ class TestBaseManager(unittest.TestCase):
def test_get_participants(self):
self.bm.connect('123', '/')
self.bm.connect('456', '/')
self.bm.connect('789', '/')
self.bm.disconnect('789', '/')
assert '789' not in self.bm.rooms['/'][None]
sid = self.bm.connect('789', '/')
self.bm.disconnect(sid, '/')
assert sid not in self.bm.rooms['/'][None]
participants = list(self.bm.get_participants('/', None))
assert len(participants) == 2
assert '789' not in participants
assert sid not in participants
def test_leave_invalid_room(self):
self.bm.connect('123', '/foo')
@ -169,11 +166,11 @@ class TestBaseManager(unittest.TestCase):
assert [] == rooms
def test_close_room(self):
self.bm.connect('123', '/foo')
sid1 = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo')
self.bm.connect('789', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.close_room('bar', '/foo')
assert 'bar' not in self.bm.rooms['/foo']
@ -181,26 +178,26 @@ class TestBaseManager(unittest.TestCase):
self.bm.close_room('bar', '/foo')
def test_rooms(self):
self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
r = self.bm.get_rooms('123', '/foo')
sid = self.bm.connect('123', '/foo')
self.bm.enter_room(sid, '/foo', 'bar')
r = self.bm.get_rooms(sid, '/foo')
assert len(r) == 2
assert '123' in r
assert sid in r
assert 'bar' in r
def test_emit_to_sid(self):
self.bm.connect('123', '/foo')
sid = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='123')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room=sid)
self.bm.server._emit_internal.assert_called_once_with(
'123', 'my event', {'foo': 'bar'}, '/foo', None
)
def test_emit_to_room(self):
self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar')
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='bar')
assert self.bm.server._emit_internal.call_count == 2
@ -212,10 +209,10 @@ class TestBaseManager(unittest.TestCase):
)
def test_emit_to_all(self):
self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar')
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')
@ -231,14 +228,14 @@ class TestBaseManager(unittest.TestCase):
)
def test_emit_to_all_skip_one(self):
self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar')
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar')
self.bm.emit(
'my event', {'foo': 'bar'}, namespace='/foo', skip_sid='456'
'my event', {'foo': 'bar'}, namespace='/foo', skip_sid=sid2
)
assert self.bm.server._emit_internal.call_count == 2
self.bm.server._emit_internal.assert_any_call(
@ -249,17 +246,17 @@ class TestBaseManager(unittest.TestCase):
)
def test_emit_to_all_skip_two(self):
self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar')
self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar')
self.bm.connect('789', '/foo')
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
sid3 = self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar')
self.bm.emit(
'my event',
{'foo': 'bar'},
namespace='/foo',
skip_sid=['123', '789'],
skip_sid=[sid1, sid3],
)
assert self.bm.server._emit_internal.call_count == 1
self.bm.server._emit_internal.assert_any_call(
@ -267,13 +264,13 @@ class TestBaseManager(unittest.TestCase):
)
def test_emit_with_callback(self):
self.bm.connect('123', '/foo')
sid = self.bm.connect('123', '/foo')
self.bm._generate_ack_id = mock.MagicMock()
self.bm._generate_ack_id.return_value = 11
self.bm.emit(
'my event', {'foo': 'bar'}, namespace='/foo', callback='cb'
)
self.bm._generate_ack_id.assert_called_once_with('123', '/foo', 'cb')
self.bm._generate_ack_id.assert_called_once_with(sid, '/foo', 'cb')
self.bm.server._emit_internal.assert_called_once_with(
'123', 'my event', {'foo': 'bar'}, '/foo', 11
)

Loading…
Cancel
Save