Browse Source

v5 protocol: handle per-namespace sids in base manager

pull/599/head
Miguel Grinberg 4 years ago
parent
commit
308b0c8eeb
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 1
      setup.py
  2. 4
      socketio/asyncio_client.py
  3. 4
      socketio/asyncio_manager.py
  4. 114
      socketio/asyncio_server.py
  5. 36
      socketio/base_manager.py
  6. 4
      socketio/client.py
  7. 113
      socketio/server.py
  8. 183
      tests/common/test_base_manager.py

1
setup.py

@ -30,6 +30,7 @@ setup(
platforms='any', platforms='any',
install_requires=[ install_requires=[
'six>=1.9.0', 'six>=1.9.0',
'bidict>=0.21.0',
'python-engineio>=3.13.0,<4' 'python-engineio>=3.13.0,<4'
], ],
extras_require={ extras_require={

4
socketio/asyncio_client.py

@ -353,7 +353,7 @@ class AsyncClient(client.Client):
if namespace in self.namespaces: if namespace in self.namespaces:
self.namespaces.remove(namespace) self.namespaces.remove(namespace)
if namespace == '/': if namespace == '/':
self.namespaces = [] self.namespaces = {}
self.connected = False self.connected = False
async def _trigger_event(self, event, namespace, *args): async def _trigger_event(self, event, namespace, *args):
@ -456,7 +456,7 @@ class AsyncClient(client.Client):
if self.connected: if self.connected:
for n in self.namespaces: for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n) await self._trigger_event('disconnect', namespace=n)
self.namespaces = [] self.namespaces = {}
self.connected = False self.connected = False
self.callbacks = {} self.callbacks = {}
self._binary_packet = None self._binary_packet = None

4
socketio/asyncio_manager.py

@ -20,13 +20,13 @@ class AsyncManager(BaseManager):
tasks = [] tasks = []
if not isinstance(skip_sid, list): if not isinstance(skip_sid, list):
skip_sid = [skip_sid] skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room): for sid, eio_sid in self.get_participants(namespace, room):
if sid not in skip_sid: if sid not in skip_sid:
if callback is not None: if callback is not None:
id = self._generate_ack_id(sid, namespace, callback) id = self._generate_ack_id(sid, namespace, callback)
else: else:
id = None id = None
tasks.append(self.server._emit_internal(sid, event, data, tasks.append(self.server._emit_internal(eio_sid, event, data,
namespace, id)) namespace, id))
if tasks == []: # pragma: no cover if tasks == []: # pragma: no cover
return return

114
socketio/asyncio_server.py

@ -336,13 +336,11 @@ class AsyncServer(server.Server):
delete_it = await self.manager.can_disconnect(sid, namespace) delete_it = await self.manager.can_disconnect(sid, namespace)
if delete_it: if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace) self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(sid, packet.Packet(packet.DISCONNECT, await self._send_packet(eio_sid, packet.Packet(
namespace=namespace)) packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid) await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace) self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
await self.eio.disconnect(sid)
async def handle_request(self, *args, **kwargs): async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client. """Handle an HTTP request from the client.
@ -396,26 +394,26 @@ class AsyncServer(server.Server):
await self._send_packet(sid, packet.Packet( await self._send_packet(sid, packet.Packet(
packet.EVENT, namespace=namespace, data=[event] + data, id=id)) packet.EVENT, namespace=namespace, data=[event] + data, id=id))
async def _send_packet(self, sid, pkt): async def _send_packet(self, eio_sid, pkt):
"""Send a Socket.IO packet to a client.""" """Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode() encoded_packet = pkt.encode()
if isinstance(encoded_packet, list): if isinstance(encoded_packet, list):
for ep in encoded_packet: for ep in encoded_packet:
await self.eio.send(sid, ep) await self.eio.send(eio_sid, ep)
else: else:
await self.eio.send(sid, encoded_packet) await self.eio.send(eio_sid, encoded_packet)
async def _handle_connect(self, sid, namespace): async def _handle_connect(self, eio_sid, namespace):
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or '/' namespace = namespace or '/'
self.manager.connect(sid, namespace) sid = self.manager.connect(eio_sid, namespace)
if self.always_connect: if self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT, await self._send_packet(eio_sid, packet.Packet(
namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = None fail_reason = None
try: try:
success = await self._trigger_event('connect', namespace, sid, success = await self._trigger_event('connect', namespace, sid,
self.environ[sid]) self.environ[eio_sid])
except exceptions.ConnectionRefusedError as exc: except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args fail_reason = exc.error_args
success = False success = False
@ -423,40 +421,34 @@ class AsyncServer(server.Server):
if success is False: if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet( await self._send_packet(eio_sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
elif namespace != '/': elif namespace != '/':
await self._send_packet(sid, packet.Packet( await self._send_packet(eio_sid, packet.Packet(
packet.CONNECT_ERROR, data=fail_reason, packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace)) namespace=namespace))
self.manager.disconnect(sid, namespace) self.manager.disconnect(sid, namespace)
if namespace == '/' and sid in self.environ: # pragma: no cover if namespace == '/' and \
del self.environ[sid] eio_sid in self.environ: # pragma: no cover
del self.environ[eio_sid]
return fail_reason or False return fail_reason or False
elif not self.always_connect: elif not self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT, await self._send_packet(eio_sid, packet.Packet(
namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
async def _handle_disconnect(self, sid, namespace): async def _handle_disconnect(self, eio_sid, namespace):
"""Handle a client disconnect.""" """Handle a client disconnect."""
namespace = namespace or '/' namespace = namespace or '/'
if namespace == '/': sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
namespace_list = list(self.manager.get_namespaces()) if self.manager.is_connected(sid, namespace):
else: self.manager.pre_disconnect(sid, namespace=namespace)
namespace_list = [namespace] await self._trigger_event('disconnect', namespace, sid)
for n in namespace_list: self.manager.disconnect(sid, namespace)
if n != '/' and self.manager.is_connected(sid, n):
self.manager.pre_disconnect(sid, namespace=n) async def _handle_event(self, eio_sid, namespace, id, data):
await self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
self.manager.pre_disconnect(sid, namespace='/')
await self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
async def _handle_event(self, sid, namespace, id, data):
"""Handle an incoming client event.""" """Handle an incoming client event."""
namespace = namespace or '/' namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received event "%s" from %s [%s]', data[0], sid, self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace) namespace)
if not self.manager.is_connected(sid, namespace): if not self.manager.is_connected(sid, namespace):
@ -465,11 +457,13 @@ class AsyncServer(server.Server):
return return
if self.async_handlers: if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid, self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id) eio_sid, data, namespace, id)
else: else:
await self._handle_event_internal(self, sid, data, namespace, id) await self._handle_event_internal(self, sid, eio_sid, data,
namespace, id)
async def _handle_event_internal(self, server, sid, data, namespace, id): async def _handle_event_internal(self, server, sid, eio_sid, data,
namespace, id):
r = await server._trigger_event(data[0], namespace, sid, *data[1:]) r = await server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None: if id is not None:
# send ACK packet with the response returned by the handler # send ACK packet with the response returned by the handler
@ -480,13 +474,13 @@ class AsyncServer(server.Server):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
await server._send_packet(sid, packet.Packet(packet.ACK, await server._send_packet(eio_sid, packet.Packet(
namespace=namespace, packet.ACK, namespace=namespace, id=id, data=data))
id=id, data=data))
async def _handle_ack(self, sid, namespace, id, data): async def _handle_ack(self, eio_sid, namespace, id, data):
"""Handle ACK packets from the client.""" """Handle ACK packets from the client."""
namespace = namespace or '/' namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received ack from %s [%s]', sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace)
await self.manager.trigger_callback(sid, namespace, id, data) await self.manager.trigger_callback(sid, namespace, id, data)
@ -509,48 +503,50 @@ class AsyncServer(server.Server):
return await self.namespace_handlers[namespace].trigger_event( return await self.namespace_handlers[namespace].trigger_event(
event, *args) event, *args)
async def _handle_eio_connect(self, sid, environ): async def _handle_eio_connect(self, eio_sid, environ):
"""Handle the Engine.IO connection event.""" """Handle the Engine.IO connection event."""
if not self.manager_initialized: if not self.manager_initialized:
self.manager_initialized = True self.manager_initialized = True
self.manager.initialize() self.manager.initialize()
self.environ[sid] = environ self.environ[eio_sid] = environ
async def _handle_eio_message(self, sid, data): async def _handle_eio_message(self, eio_sid, data):
"""Dispatch Engine.IO messages.""" """Dispatch Engine.IO messages."""
if sid in self._binary_packet: if eio_sid in self._binary_packet:
pkt = self._binary_packet[sid] pkt = self._binary_packet[eio_sid]
if pkt.add_attachment(data): if pkt.add_attachment(data):
del self._binary_packet[sid] del self._binary_packet[eio_sid]
if pkt.packet_type == packet.BINARY_EVENT: if pkt.packet_type == packet.BINARY_EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id, await self._handle_event(eio_sid, pkt.namespace, pkt.id,
pkt.data) pkt.data)
else: else:
await self._handle_ack(sid, pkt.namespace, pkt.id, await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data) pkt.data)
else: else:
pkt = packet.Packet(encoded_packet=data) pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
await self._handle_connect(sid, pkt.namespace) await self._handle_connect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(sid, pkt.namespace) await self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT: elif pkt.packet_type == packet.EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) await self._handle_event(eio_sid, pkt.namespace, pkt.id,
pkt.data)
elif pkt.packet_type == packet.ACK: elif pkt.packet_type == packet.ACK:
await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \ elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK: pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt self._binary_packet[eio_sid] = pkt
elif pkt.packet_type == packet.CONNECT_ERROR: elif pkt.packet_type == packet.CONNECT_ERROR:
raise ValueError('Unexpected CONNECT_ERROR packet.') raise ValueError('Unexpected CONNECT_ERROR packet.')
else: else:
raise ValueError('Unknown packet type.') raise ValueError('Unknown packet type.')
async def _handle_eio_disconnect(self, sid): async def _handle_eio_disconnect(self, eio_sid):
"""Handle Engine.IO disconnect event.""" """Handle Engine.IO disconnect event."""
await self._handle_disconnect(sid, '/') await self._handle_disconnect(eio_sid, '/')
if sid in self.environ: if eio_sid in self.environ:
del self.environ[sid] del self.environ[eio_sid]
def _engineio_server_class(self): def _engineio_server_class(self):
return engineio.AsyncServer return engineio.AsyncServer

36
socketio/base_manager.py

@ -1,6 +1,7 @@
import itertools import itertools
import logging import logging
from bidict import bidict
import six import six
default_logger = logging.getLogger('socketio') default_logger = logging.getLogger('socketio')
@ -18,7 +19,8 @@ class BaseManager(object):
def __init__(self): def __init__(self):
self.logger = None self.logger = None
self.server = None self.server = None
self.rooms = {} self.rooms = {} # self.rooms[namespace][room][sio_sid] = eio_sid
self.eio_to_sid = {}
self.callbacks = {} self.callbacks = {}
self.pending_disconnect = {} self.pending_disconnect = {}
@ -37,13 +39,15 @@ class BaseManager(object):
def get_participants(self, namespace, room): def get_participants(self, namespace, room):
"""Return an iterable with the active participants in a room.""" """Return an iterable with the active participants in a room."""
for sid, active in six.iteritems(self.rooms[namespace][room].copy()): for sid, eio_sid in self.rooms[namespace][room].copy().items():
yield sid yield sid, eio_sid
def connect(self, sid, namespace): def connect(self, eio_sid, namespace):
"""Register a client connection to a namespace.""" """Register a client connection to a namespace."""
self.enter_room(sid, namespace, None) sid = self.server.eio.generate_id()
self.enter_room(sid, namespace, sid) self.enter_room(sid, namespace, None, eio_sid=eio_sid)
self.enter_room(sid, namespace, sid, eio_sid=eio_sid)
return sid
def is_connected(self, sid, namespace): def is_connected(self, sid, namespace):
if namespace in self.pending_disconnect and \ if namespace in self.pending_disconnect and \
@ -55,6 +59,9 @@ class BaseManager(object):
except KeyError: except KeyError:
pass pass
def sid_from_eio_sid(self, eio_sid, namespace):
return self.rooms[namespace][None].inverse.get(eio_sid)
def can_disconnect(self, sid, namespace): def can_disconnect(self, sid, namespace):
return self.is_connected(sid, namespace) return self.is_connected(sid, namespace)
@ -68,6 +75,7 @@ class BaseManager(object):
if namespace not in self.pending_disconnect: if namespace not in self.pending_disconnect:
self.pending_disconnect[namespace] = [] self.pending_disconnect[namespace] = []
self.pending_disconnect[namespace].append(sid) self.pending_disconnect[namespace].append(sid)
return self.rooms[namespace][None].get(sid)
def disconnect(self, sid, namespace): def disconnect(self, sid, namespace):
"""Register a client disconnect from a namespace.""" """Register a client disconnect from a namespace."""
@ -89,13 +97,15 @@ class BaseManager(object):
if len(self.pending_disconnect[namespace]) == 0: if len(self.pending_disconnect[namespace]) == 0:
del self.pending_disconnect[namespace] del self.pending_disconnect[namespace]
def enter_room(self, sid, namespace, room): def enter_room(self, sid, namespace, room, eio_sid=None):
"""Add a client to a room.""" """Add a client to a room."""
if namespace not in self.rooms: if namespace not in self.rooms:
self.rooms[namespace] = {} self.rooms[namespace] = {}
if room not in self.rooms[namespace]: if room not in self.rooms[namespace]:
self.rooms[namespace][room] = {} self.rooms[namespace][room] = bidict()
self.rooms[namespace][room][sid] = True if eio_sid is None:
eio_sid = self.rooms[namespace][None][sid]
self.rooms[namespace][room][sid] = eio_sid
def leave_room(self, sid, namespace, room): def leave_room(self, sid, namespace, room):
"""Remove a client from a room.""" """Remove a client from a room."""
@ -111,7 +121,7 @@ class BaseManager(object):
def close_room(self, room, namespace): def close_room(self, room, namespace):
"""Remove all participants from a room.""" """Remove all participants from a room."""
try: try:
for sid in self.get_participants(namespace, room): for sid, _ in self.get_participants(namespace, room):
self.leave_room(sid, namespace, room) self.leave_room(sid, namespace, room)
except KeyError: except KeyError:
pass pass
@ -121,7 +131,7 @@ class BaseManager(object):
r = [] r = []
try: try:
for room_name, room in six.iteritems(self.rooms[namespace]): for room_name, room in six.iteritems(self.rooms[namespace]):
if room_name is not None and sid in room and room[sid]: if room_name is not None and sid in room:
r.append(room_name) r.append(room_name)
except KeyError: except KeyError:
pass pass
@ -135,13 +145,13 @@ class BaseManager(object):
return return
if not isinstance(skip_sid, list): if not isinstance(skip_sid, list):
skip_sid = [skip_sid] skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room): for sid, eio_sid in self.get_participants(namespace, room):
if sid not in skip_sid: if sid not in skip_sid:
if callback is not None: if callback is not None:
id = self._generate_ack_id(sid, namespace, callback) id = self._generate_ack_id(sid, namespace, callback)
else: else:
id = None id = None
self.server._emit_internal(sid, event, data, namespace, id) self.server._emit_internal(eio_sid, event, data, namespace, id)
def trigger_callback(self, sid, namespace, id, data): def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback.""" """Invoke an application callback."""

4
socketio/client.py

@ -533,7 +533,7 @@ class Client(object):
if namespace in self.namespaces: if namespace in self.namespaces:
del self.namespaces[namespace] del self.namespaces[namespace]
if namespace == '/': if namespace == '/':
self.namespaces = [] self.namespaces = {}
self.connected = False self.connected = False
def _trigger_event(self, event, namespace, *args): def _trigger_event(self, event, namespace, *args):
@ -625,7 +625,7 @@ class Client(object):
if self.connected: if self.connected:
for n in self.namespaces: for n in self.namespaces:
self._trigger_event('disconnect', namespace=n) self._trigger_event('disconnect', namespace=n)
self.namespaces = [] self.namespaces = {}
self.connected = False self.connected = False
self.callbacks = {} self.callbacks = {}
self._binary_packet = None self._binary_packet = None

113
socketio/server.py

@ -520,13 +520,11 @@ class Server(object):
delete_it = self.manager.can_disconnect(sid, namespace) delete_it = self.manager.can_disconnect(sid, namespace)
if delete_it: if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace) self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(sid, packet.Packet(packet.DISCONNECT, self._send_packet(eio_sid, packet.Packet(
namespace=namespace)) packet.DISCONNECT, namespace=namespace))
self._trigger_event('disconnect', namespace, sid) self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace) self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
self.eio.disconnect(sid)
def transport(self, sid): def transport(self, sid):
"""Return the name of the transport used by the client. """Return the name of the transport used by the client.
@ -594,26 +592,26 @@ class Server(object):
self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace, self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id)) data=[event] + data, id=id))
def _send_packet(self, sid, pkt): def _send_packet(self, eio_sid, pkt):
"""Send a Socket.IO packet to a client.""" """Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode() encoded_packet = pkt.encode()
if isinstance(encoded_packet, list): if isinstance(encoded_packet, list):
for ep in encoded_packet: for ep in encoded_packet:
self.eio.send(sid, ep) self.eio.send(eio_sid, ep)
else: else:
self.eio.send(sid, encoded_packet) self.eio.send(eio_sid, encoded_packet)
def _handle_connect(self, sid, namespace): def _handle_connect(self, eio_sid, namespace):
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or '/' namespace = namespace or '/'
self.manager.connect(sid, namespace) sid = self.manager.connect(eio_sid, namespace)
if self.always_connect: if self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT, self._send_packet(eio_sid, packet.Packet(
namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = None fail_reason = None
try: try:
success = self._trigger_event('connect', namespace, sid, success = self._trigger_event('connect', namespace, sid,
self.environ[sid]) self.environ[eio_sid])
except exceptions.ConnectionRefusedError as exc: except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args fail_reason = exc.error_args
success = False success = False
@ -621,40 +619,34 @@ class Server(object):
if success is False: if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
self._send_packet(sid, packet.Packet( self._send_packet(eio_sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
elif namespace != '/': elif namespace != '/':
self._send_packet(sid, packet.Packet( self._send_packet(eio_sid, packet.Packet(
packet.CONNECT_ERROR, data=fail_reason, packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace)) namespace=namespace))
self.manager.disconnect(sid, namespace) self.manager.disconnect(sid, namespace)
if namespace == '/' and sid in self.environ: # pragma: no cover if namespace == '/' and \
del self.environ[sid] eio_sid in self.environ: # pragma: no cover
del self.environ[eio_sid]
return fail_reason or False return fail_reason or False
elif not self.always_connect: elif not self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT, self._send_packet(eio_sid, packet.Packet(
namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
def _handle_disconnect(self, sid, namespace): def _handle_disconnect(self, eio_sid, namespace):
"""Handle a client disconnect.""" """Handle a client disconnect."""
namespace = namespace or '/' namespace = namespace or '/'
if namespace == '/': sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
namespace_list = list(self.manager.get_namespaces()) if self.manager.is_connected(sid, namespace):
else: self.manager.pre_disconnect(sid, namespace=namespace)
namespace_list = [namespace] self._trigger_event('disconnect', namespace, sid)
for n in namespace_list: self.manager.disconnect(sid, namespace)
if n != '/' and self.manager.is_connected(sid, n):
self.manager.pre_disconnect(sid, namespace=n) def _handle_event(self, eio_sid, namespace, id, data):
self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
self.manager.pre_disconnect(sid, namespace='/')
self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
def _handle_event(self, sid, namespace, id, data):
"""Handle an incoming client event.""" """Handle an incoming client event."""
namespace = namespace or '/' namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received event "%s" from %s [%s]', data[0], sid, self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace) namespace)
if not self.manager.is_connected(sid, namespace): if not self.manager.is_connected(sid, namespace):
@ -663,11 +655,13 @@ class Server(object):
return return
if self.async_handlers: if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid, self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id) eio_sid, data, namespace, id)
else: else:
self._handle_event_internal(self, sid, data, namespace, id) self._handle_event_internal(self, sid, eio_sid, data, namespace,
id)
def _handle_event_internal(self, server, sid, data, namespace, id): def _handle_event_internal(self, server, sid, eio_sid, data, namespace,
id):
r = server._trigger_event(data[0], namespace, sid, *data[1:]) r = server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None: if id is not None:
# send ACK packet with the response returned by the handler # send ACK packet with the response returned by the handler
@ -678,13 +672,13 @@ class Server(object):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
server._send_packet(sid, packet.Packet(packet.ACK, server._send_packet(eio_sid, packet.Packet(
namespace=namespace, packet.ACK, namespace=namespace, id=id, data=data))
id=id, data=data))
def _handle_ack(self, sid, namespace, id, data): def _handle_ack(self, eio_sid, namespace, id, data):
"""Handle ACK packets from the client.""" """Handle ACK packets from the client."""
namespace = namespace or '/' namespace = namespace or '/'
sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info('received ack from %s [%s]', sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace)
self.manager.trigger_callback(sid, namespace, id, data) self.manager.trigger_callback(sid, namespace, id, data)
@ -699,46 +693,47 @@ class Server(object):
return self.namespace_handlers[namespace].trigger_event( return self.namespace_handlers[namespace].trigger_event(
event, *args) event, *args)
def _handle_eio_connect(self, sid, environ): def _handle_eio_connect(self, eio_sid, environ):
"""Handle the Engine.IO connection event.""" """Handle the Engine.IO connection event."""
if not self.manager_initialized: if not self.manager_initialized:
self.manager_initialized = True self.manager_initialized = True
self.manager.initialize() self.manager.initialize()
self.environ[sid] = environ self.environ[eio_sid] = environ
def _handle_eio_message(self, sid, data): def _handle_eio_message(self, eio_sid, data):
"""Dispatch Engine.IO messages.""" """Dispatch Engine.IO messages."""
if sid in self._binary_packet: if eio_sid in self._binary_packet:
pkt = self._binary_packet[sid] pkt = self._binary_packet[eio_sid]
if pkt.add_attachment(data): if pkt.add_attachment(data):
del self._binary_packet[sid] del self._binary_packet[eio_sid]
if pkt.packet_type == packet.BINARY_EVENT: if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) self._handle_event(eio_sid, pkt.namespace, pkt.id,
pkt.data)
else: else:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = packet.Packet(encoded_packet=data) pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
self._handle_connect(sid, pkt.namespace) self._handle_connect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(sid, pkt.namespace) self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT: elif pkt.packet_type == packet.EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK: elif pkt.packet_type == packet.ACK:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \ elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK: pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt self._binary_packet[eio_sid] = pkt
elif pkt.packet_type == packet.CONNECT_ERROR: elif pkt.packet_type == packet.CONNECT_ERROR:
raise ValueError('Unexpected CONNECT_ERROR packet.') raise ValueError('Unexpected CONNECT_ERROR packet.')
else: else:
raise ValueError('Unknown packet type.') raise ValueError('Unknown packet type.')
def _handle_eio_disconnect(self, sid): def _handle_eio_disconnect(self, eio_sid):
"""Handle Engine.IO disconnect event.""" """Handle Engine.IO disconnect event."""
self._handle_disconnect(sid, '/') self._handle_disconnect(eio_sid, '/')
if sid in self.environ: if eio_sid in self.environ:
del self.environ[sid] del self.environ[eio_sid]
def _engineio_server_class(self): def _engineio_server_class(self):
return engineio.Server return engineio.Server

183
tests/common/test_base_manager.py

@ -12,21 +12,28 @@ from socketio import base_manager
class TestBaseManager(unittest.TestCase): class TestBaseManager(unittest.TestCase):
def setUp(self): def setUp(self):
id = 0
def generate_id():
nonlocal id
id += 1
return str(id)
mock_server = mock.MagicMock() mock_server = mock.MagicMock()
mock_server.eio.generate_id = generate_id
self.bm = base_manager.BaseManager() self.bm = base_manager.BaseManager()
self.bm.set_server(mock_server) self.bm.set_server(mock_server)
self.bm.initialize() self.bm.initialize()
def test_connect(self): def test_connect(self):
self.bm.connect('123', '/foo') sid = self.bm.connect('123', '/foo')
assert None in self.bm.rooms['/foo'] assert None in self.bm.rooms['/foo']
assert '123' in self.bm.rooms['/foo'] assert sid in self.bm.rooms['/foo']
assert '123' in self.bm.rooms['/foo'][None] assert sid in self.bm.rooms['/foo'][None]
assert '123' in self.bm.rooms['/foo']['123'] assert sid in self.bm.rooms['/foo'][sid]
assert self.bm.rooms['/foo'] == { assert dict(self.bm.rooms['/foo'][None]) == {sid: '123'}
None: {'123': True}, assert dict(self.bm.rooms['/foo'][sid]) == {sid: '123'}
'123': {'123': True}, assert self.bm.sid_from_eio_sid('123', '/foo') == sid
}
def test_pre_disconnect(self): def test_pre_disconnect(self):
self.bm.connect('123', '/foo') self.bm.connect('123', '/foo')
@ -43,63 +50,53 @@ class TestBaseManager(unittest.TestCase):
assert self.bm.pending_disconnect == {} assert self.bm.pending_disconnect == {}
def test_disconnect(self): def test_disconnect(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo') sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room('456', '/foo', 'baz') self.bm.enter_room(sid2, '/foo', 'baz')
self.bm.disconnect('123', '/foo') self.bm.disconnect(sid1, '/foo')
assert self.bm.rooms['/foo'] == { assert dict(self.bm.rooms['/foo'][None]) == {sid2: '456'}
None: {'456': True}, assert dict(self.bm.rooms['/foo'][sid2]) == {sid2: '456'}
'456': {'456': True}, assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'}
'baz': {'456': True},
}
def test_disconnect_default_namespace(self): def test_disconnect_default_namespace(self):
self.bm.connect('123', '/') sid1 = self.bm.connect('123', '/')
self.bm.connect('123', '/foo') sid2 = self.bm.connect('123', '/foo')
self.bm.connect('456', '/') sid3 = self.bm.connect('456', '/')
self.bm.connect('456', '/foo') sid4 = self.bm.connect('456', '/foo')
assert self.bm.is_connected('123', '/') assert self.bm.is_connected(sid1, '/')
assert self.bm.is_connected('123', '/foo') assert self.bm.is_connected(sid2, '/foo')
self.bm.disconnect('123', '/') self.bm.disconnect(sid1, '/')
assert not self.bm.is_connected('123', '/') assert not self.bm.is_connected(sid1, '/')
assert self.bm.is_connected('123', '/foo') assert self.bm.is_connected(sid2, '/foo')
self.bm.disconnect('123', '/foo') self.bm.disconnect(sid2, '/foo')
assert not self.bm.is_connected('123', '/foo') assert not self.bm.is_connected(sid2, '/foo')
assert self.bm.rooms['/'] == { assert dict(self.bm.rooms['/'][None]) == {sid3: '456'}
None: {'456': True}, assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'}
'456': {'456': True}, assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'}
} assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'}
assert self.bm.rooms['/foo'] == {
None: {'456': True},
'456': {'456': True},
}
def test_disconnect_twice(self): def test_disconnect_twice(self):
self.bm.connect('123', '/') sid1 = self.bm.connect('123', '/')
self.bm.connect('123', '/foo') sid2 = self.bm.connect('123', '/foo')
self.bm.connect('456', '/') sid3 = self.bm.connect('456', '/')
self.bm.connect('456', '/foo') sid4 = self.bm.connect('456', '/foo')
self.bm.disconnect('123', '/') self.bm.disconnect(sid1, '/')
self.bm.disconnect('123', '/foo') self.bm.disconnect(sid2, '/foo')
self.bm.disconnect('123', '/') self.bm.disconnect(sid1, '/')
self.bm.disconnect('123', '/foo') self.bm.disconnect(sid2, '/foo')
assert self.bm.rooms['/'] == { assert dict(self.bm.rooms['/'][None]) == {sid3: '456'}
None: {'456': True}, assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'}
'456': {'456': True}, assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'}
} assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'}
assert self.bm.rooms['/foo'] == {
None: {'456': True},
'456': {'456': True},
}
def test_disconnect_all(self): def test_disconnect_all(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo') sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room('456', '/foo', 'baz') self.bm.enter_room(sid2, '/foo', 'baz')
self.bm.disconnect('123', '/foo') self.bm.disconnect(sid1, '/foo')
self.bm.disconnect('456', '/foo') self.bm.disconnect(sid2, '/foo')
assert self.bm.rooms == {} assert self.bm.rooms == {}
def test_disconnect_with_callbacks(self): def test_disconnect_with_callbacks(self):
@ -152,12 +149,12 @@ class TestBaseManager(unittest.TestCase):
def test_get_participants(self): def test_get_participants(self):
self.bm.connect('123', '/') self.bm.connect('123', '/')
self.bm.connect('456', '/') self.bm.connect('456', '/')
self.bm.connect('789', '/') sid = self.bm.connect('789', '/')
self.bm.disconnect('789', '/') self.bm.disconnect(sid, '/')
assert '789' not in self.bm.rooms['/'][None] assert sid not in self.bm.rooms['/'][None]
participants = list(self.bm.get_participants('/', None)) participants = list(self.bm.get_participants('/', None))
assert len(participants) == 2 assert len(participants) == 2
assert '789' not in participants assert sid not in participants
def test_leave_invalid_room(self): def test_leave_invalid_room(self):
self.bm.connect('123', '/foo') self.bm.connect('123', '/foo')
@ -169,11 +166,11 @@ class TestBaseManager(unittest.TestCase):
assert [] == rooms assert [] == rooms
def test_close_room(self): def test_close_room(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo') self.bm.connect('456', '/foo')
self.bm.connect('789', '/foo') self.bm.connect('789', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.close_room('bar', '/foo') self.bm.close_room('bar', '/foo')
assert 'bar' not in self.bm.rooms['/foo'] assert 'bar' not in self.bm.rooms['/foo']
@ -181,26 +178,26 @@ class TestBaseManager(unittest.TestCase):
self.bm.close_room('bar', '/foo') self.bm.close_room('bar', '/foo')
def test_rooms(self): def test_rooms(self):
self.bm.connect('123', '/foo') sid = self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid, '/foo', 'bar')
r = self.bm.get_rooms('123', '/foo') r = self.bm.get_rooms(sid, '/foo')
assert len(r) == 2 assert len(r) == 2
assert '123' in r assert sid in r
assert 'bar' in r assert 'bar' in r
def test_emit_to_sid(self): def test_emit_to_sid(self):
self.bm.connect('123', '/foo') sid = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo') self.bm.connect('456', '/foo')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='123') self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room=sid)
self.bm.server._emit_internal.assert_called_once_with( self.bm.server._emit_internal.assert_called_once_with(
'123', 'my event', {'foo': 'bar'}, '/foo', None '123', 'my event', {'foo': 'bar'}, '/foo', None
) )
def test_emit_to_room(self): def test_emit_to_room(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.connect('456', '/foo') sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar') self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo') self.bm.connect('789', '/foo')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='bar') self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', room='bar')
assert self.bm.server._emit_internal.call_count == 2 assert self.bm.server._emit_internal.call_count == 2
@ -212,10 +209,10 @@ class TestBaseManager(unittest.TestCase):
) )
def test_emit_to_all(self): def test_emit_to_all(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.connect('456', '/foo') sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar') self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo') self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar') self.bm.connect('abc', '/bar')
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo') self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')
@ -231,14 +228,14 @@ class TestBaseManager(unittest.TestCase):
) )
def test_emit_to_all_skip_one(self): def test_emit_to_all_skip_one(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.connect('456', '/foo') sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar') self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo') self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar') self.bm.connect('abc', '/bar')
self.bm.emit( self.bm.emit(
'my event', {'foo': 'bar'}, namespace='/foo', skip_sid='456' 'my event', {'foo': 'bar'}, namespace='/foo', skip_sid=sid2
) )
assert self.bm.server._emit_internal.call_count == 2 assert self.bm.server._emit_internal.call_count == 2
self.bm.server._emit_internal.assert_any_call( self.bm.server._emit_internal.assert_any_call(
@ -249,17 +246,17 @@ class TestBaseManager(unittest.TestCase):
) )
def test_emit_to_all_skip_two(self): def test_emit_to_all_skip_two(self):
self.bm.connect('123', '/foo') sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room('123', '/foo', 'bar') self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.connect('456', '/foo') sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room('456', '/foo', 'bar') self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.connect('789', '/foo') sid3 = self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar') self.bm.connect('abc', '/bar')
self.bm.emit( self.bm.emit(
'my event', 'my event',
{'foo': 'bar'}, {'foo': 'bar'},
namespace='/foo', namespace='/foo',
skip_sid=['123', '789'], skip_sid=[sid1, sid3],
) )
assert self.bm.server._emit_internal.call_count == 1 assert self.bm.server._emit_internal.call_count == 1
self.bm.server._emit_internal.assert_any_call( self.bm.server._emit_internal.assert_any_call(
@ -267,13 +264,13 @@ class TestBaseManager(unittest.TestCase):
) )
def test_emit_with_callback(self): def test_emit_with_callback(self):
self.bm.connect('123', '/foo') sid = self.bm.connect('123', '/foo')
self.bm._generate_ack_id = mock.MagicMock() self.bm._generate_ack_id = mock.MagicMock()
self.bm._generate_ack_id.return_value = 11 self.bm._generate_ack_id.return_value = 11
self.bm.emit( self.bm.emit(
'my event', {'foo': 'bar'}, namespace='/foo', callback='cb' 'my event', {'foo': 'bar'}, namespace='/foo', callback='cb'
) )
self.bm._generate_ack_id.assert_called_once_with('123', '/foo', 'cb') self.bm._generate_ack_id.assert_called_once_with(sid, '/foo', 'cb')
self.bm.server._emit_internal.assert_called_once_with( self.bm.server._emit_internal.assert_called_once_with(
'123', 'my event', {'foo': 'bar'}, '/foo', 11 '123', 'my event', {'foo': 'bar'}, '/foo', 11
) )

Loading…
Cancel
Save