diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index af896fb..ddc190c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,7 +26,7 @@ jobs: exclude: # pypy3 currently fails to run on Windows - os: windows-latest - python: pypy-3.8 + python: pypy-3.9 fail-fast: false runs-on: ${{ matrix.os }} steps: diff --git a/docs/server.rst b/docs/server.rst index 5a6798c..5ce59f0 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -617,6 +617,41 @@ callbacks when emitting. When the external process needs to receive callbacks, using a client to connect to the server with read and write support is a better option than a write-only client manager. +Monitoring and Administration +----------------------------- + +The Socket.IO server can be configured to accept connections from the official +`Socket.IO Admin UI <https://socket.io/docs/v4/admin-ui/>`_. This tool provides +real-time information about currently connected clients, rooms in use and +events being emitted. It also allows an administrator to manually emit events, +change room assignments and disconnect clients. The hosted version of this tool +is available at `https://admin.socket.io <https://admin.socket.io>`_. + +Given that enabling this feature can affect the performance of the server, it +is disabled by default. To enable it, call the +:func:`instrument() <socketio.Server.instrument>` method. For example:: + + import os + import socketio + + sio = socketio.Server(cors_allowed_origins=[ + 'http://localhost:5000', + 'https://admin.socket.io', + ]) + sio.instrument(auth={ + 'username': 'admin', + 'password': os.environ['ADMIN_PASSWORD'], + }) + +This configures the server to accept connections from the hosted Admin UI +client. Administrators can then open https://admin.socket.io in their web +browsers and log in with username ``admin`` and the password given by the +``ADMIN_PASSWORD`` environment variable. To ensure the Admin UI front end is +allowed to connect, CORS is also configured. + +Consult the reference documentation to learn about additional configuration +options that are available. + Debugging and Troubleshooting ----------------------------- diff --git a/examples/server/asgi/app.py b/examples/server/asgi/app.py index 22180bb..36af85f 100644 --- a/examples/server/asgi/app.py +++ b/examples/server/asgi/app.py @@ -1,9 +1,25 @@ #!/usr/bin/env python -import uvicorn +# set instrument to `True` to accept connections from the official Socket.IO +# Admin UI hosted at https://admin.socket.io +instrument = False +admin_login = { + 'username': 'admin', + 'password': 'python', # change this to a strong secret for production use! +} + +import uvicorn import socketio -sio = socketio.AsyncServer(async_mode='asgi') +sio = socketio.AsyncServer( + async_mode='asgi', + cors_allowed_origins=None if not instrument else [ + 'http://localhost:5000', + 'https://admin.socket.io', # edit the allowed origins if necessary + ]) +if instrument: + sio.instrument(auth=admin_login) + app = socketio.ASGIApp(sio, static_files={ '/': 'app.html', }) diff --git a/examples/server/wsgi/app.py b/examples/server/wsgi/app.py index 3339826..7b019fd 100644 --- a/examples/server/wsgi/app.py +++ b/examples/server/wsgi/app.py @@ -3,10 +3,26 @@ # installed async_mode = None +# set instrument to `True` to accept connections from the official Socket.IO +# Admin UI hosted at https://admin.socket.io +instrument = False +admin_login = { + 'username': 'admin', + 'password': 'python', # change this to a strong secret for production use! +} + from flask import Flask, render_template import socketio -sio = socketio.Server(logger=True, async_mode=async_mode) +sio = socketio.Server( + async_mode=async_mode, + cors_allowed_origins=None if not instrument else [ + 'http://localhost:5000', + 'https://admin.socket.io', # edit the allowed origins if necessary + ]) +if instrument: + sio.instrument(auth=admin_login) + app = Flask(__name__) app.wsgi_app = socketio.WSGIApp(sio, app.wsgi_app) app.config['SECRET_KEY'] = 'secret!' diff --git a/examples/server/wsgi/templates/index.html b/examples/server/wsgi/templates/index.html index 7c9ae41..bec1a62 100644 --- a/examples/server/wsgi/templates/index.html +++ b/examples/server/wsgi/templates/index.html @@ -6,7 +6,7 @@ <script type="text/javascript" src="//cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.min.js"></script> <script type="text/javascript" charset="utf-8"> $(document).ready(function(){ - var socket = io.connect({transports: ['websocket']}); + var socket = io.connect(); socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); diff --git a/src/socketio/admin.py b/src/socketio/admin.py new file mode 100644 index 0000000..f317ea2 --- /dev/null +++ b/src/socketio/admin.py @@ -0,0 +1,405 @@ +from datetime import datetime +import functools +import os +import socket +import time +from urllib.parse import parse_qs +from .exceptions import ConnectionRefusedError + +HOSTNAME = socket.gethostname() +PID = os.getpid() + + +class EventBuffer: + def __init__(self): + self.buffer = {} + + def push(self, type, count=1): + timestamp = int(time.time()) * 1000 + key = '{};{}'.format(timestamp, type) + if key not in self.buffer: + self.buffer[key] = { + 'timestamp': timestamp, + 'type': type, + 'count': count, + } + else: + self.buffer[key]['count'] += count + + def get_and_clear(self): + buffer = self.buffer + self.buffer = {} + return [value for value in buffer.values()] + + +class InstrumentedServer: + def __init__(self, sio, auth=None, mode='development', read_only=False, + server_id=None, namespace='/admin', server_stats_interval=2): + """Instrument the Socket.IO server for monitoring with the `Socket.IO + Admin UI <https://socket.io/docs/v4/admin-ui/>`_. + """ + if auth is None: + raise ValueError('auth must be specified') + self.sio = sio + self.auth = auth + self.admin_namespace = namespace + self.read_only = read_only + self.server_id = server_id or ( + self.sio.manager.host_id if hasattr(self.sio.manager, 'host_id') + else HOSTNAME + ) + self.mode = mode + self.server_stats_interval = server_stats_interval + self.event_buffer = EventBuffer() + + # task that emits "server_stats" every 2 seconds + self.stop_stats_event = None + self.stats_task = None + + # monkey-patch the server to report metrics to the admin UI + self.instrument() + + def instrument(self): + self.sio.on('connect', self.admin_connect, + namespace=self.admin_namespace) + + if self.mode == 'development': + if not self.read_only: # pragma: no branch + self.sio.on('emit', self.admin_emit, + namespace=self.admin_namespace) + self.sio.on('join', self.admin_enter_room, + namespace=self.admin_namespace) + self.sio.on('leave', self.admin_leave_room, + namespace=self.admin_namespace) + self.sio.on('_disconnect', self.admin_disconnect, + namespace=self.admin_namespace) + + # track socket connection times + self.sio.manager._timestamps = {} + + # report socket.io connections + self.sio.manager.__connect = self.sio.manager.connect + self.sio.manager.connect = self._connect + + # report socket.io disconnection + self.sio.manager.__disconnect = self.sio.manager.disconnect + self.sio.manager.disconnect = self._disconnect + + # report join rooms + self.sio.manager.__basic_enter_room = \ + self.sio.manager.basic_enter_room + self.sio.manager.basic_enter_room = self._basic_enter_room + + # report leave rooms + self.sio.manager.__basic_leave_room = \ + self.sio.manager.basic_leave_room + self.sio.manager.basic_leave_room = self._basic_leave_room + + # report emit events + self.sio.manager.__emit = self.sio.manager.emit + self.sio.manager.emit = self._emit + + # report receive events + self.sio.__handle_event_internal = self.sio._handle_event_internal + self.sio._handle_event_internal = self._handle_event_internal + + # report engine.io connections + self.sio.eio.on('connect', self._handle_eio_connect) + self.sio.eio.on('disconnect', self._handle_eio_disconnect) + + # report polling packets + from engineio.socket import Socket + self.sio.eio.__ok = self.sio.eio._ok + self.sio.eio._ok = self._eio_http_response + Socket.__handle_post_request = Socket.handle_post_request + Socket.handle_post_request = functools.partialmethod( + self.__class__._eio_handle_post_request, self) + + # report websocket packets + Socket.__websocket_handler = Socket._websocket_handler + Socket._websocket_handler = functools.partialmethod( + self.__class__._eio_websocket_handler, self) + + # report connected sockets with each ping + if self.mode == 'development': + Socket.__send_ping = Socket._send_ping + Socket._send_ping = functools.partialmethod( + self.__class__._eio_send_ping, self) + + def uninstrument(self): # pragma: no cover + if self.mode == 'development': + self.sio.manager.connect = self.sio.manager.__connect + self.sio.manager.disconnect = self.sio.manager.__disconnect + self.sio.manager.basic_enter_room = \ + self.sio.manager.__basic_enter_room + self.sio.manager.basic_leave_room = \ + self.sio.manager.__basic_leave_room + self.sio.manager.emit = self.sio.manager.__emit + self.sio._handle_event_internal = self.sio.__handle_event_internal + self.sio.eio._ok = self.sio.eio.__ok + + from engineio.socket import Socket + Socket.handle_post_request = Socket.__handle_post_request + Socket._websocket_handler = Socket.__websocket_handler + if self.mode == 'development': + Socket._send_ping = Socket.__send_ping + + def admin_connect(self, sid, environ, client_auth): + if self.auth: + authenticated = False + if isinstance(self.auth, dict): + authenticated = client_auth == self.auth + elif isinstance(self.auth, list): + authenticated = client_auth in self.auth + else: + authenticated = self.auth(client_auth) + if not authenticated: + raise ConnectionRefusedError('authentication failed') + + def config(sid): + self.sio.sleep(0.1) + + # supported features + features = ['AGGREGATED_EVENTS'] + if not self.read_only: + features += ['EMIT', 'JOIN', 'LEAVE', 'DISCONNECT', 'MJOIN', + 'MLEAVE', 'MDISCONNECT'] + if self.mode == 'development': + features.append('ALL_EVENTS') + self.sio.emit('config', {'supportedFeatures': features}, + to=sid, namespace=self.admin_namespace) + + # send current sockets + if self.mode == 'development': + all_sockets = [] + for nsp in self.sio.manager.get_namespaces(): + for sid, eio_sid in self.sio.manager.get_participants( + nsp, None): + all_sockets.append( + self.serialize_socket(sid, nsp, eio_sid)) + self.sio.emit('all_sockets', all_sockets, to=sid, + namespace=self.admin_namespace) + + self.sio.start_background_task(config, sid) + + def admin_emit(self, _, namespace, room_filter, event, *data): + self.sio.emit(event, data, to=room_filter, namespace=namespace) + + def admin_enter_room(self, _, namespace, room, room_filter=None): + for sid, _ in self.sio.manager.get_participants( + namespace, room_filter): + self.sio.enter_room(sid, room, namespace=namespace) + + def admin_leave_room(self, _, namespace, room, room_filter=None): + for sid, _ in self.sio.manager.get_participants( + namespace, room_filter): + self.sio.leave_room(sid, room, namespace=namespace) + + def admin_disconnect(self, _, namespace, close, room_filter=None): + for sid, _ in self.sio.manager.get_participants( + namespace, room_filter): + self.sio.disconnect(sid, namespace=namespace) + + def shutdown(self): + if self.stats_task: # pragma: no branch + self.stop_stats_event.set() + self.stats_task.join() + + def _connect(self, eio_sid, namespace): + sid = self.sio.manager.__connect(eio_sid, namespace) + t = time.time() + self.sio.manager._timestamps[sid] = t + serialized_socket = self.serialize_socket(sid, namespace, eio_sid) + self.sio.emit('socket_connected', ( + serialized_socket, + datetime.utcfromtimestamp(t).isoformat() + 'Z', + ), namespace=self.admin_namespace) + return sid + + def _disconnect(self, sid, namespace, **kwargs): + del self.sio.manager._timestamps[sid] + self.sio.emit('socket_disconnected', ( + namespace, + sid, + 'N/A', + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return self.sio.manager.__disconnect(sid, namespace, **kwargs) + + def _check_for_upgrade(self, eio_sid, sid, namespace): # pragma: no cover + for _ in range(5): + self.sio.sleep(5) + try: + if self.sio.eio._get_socket(eio_sid).upgraded: + self.sio.emit('socket_updated', { + 'id': sid, + 'nsp': namespace, + 'transport': 'websocket', + }, namespace=self.admin_namespace) + break + except KeyError: + pass + + def _basic_enter_room(self, sid, namespace, room, eio_sid=None): + ret = self.sio.manager.__basic_enter_room(sid, namespace, room, + eio_sid) + if room: + self.sio.emit('room_joined', ( + namespace, + room, + sid, + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return ret + + def _basic_leave_room(self, sid, namespace, room): + if room: + self.sio.emit('room_left', ( + namespace, + room, + sid, + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return self.sio.manager.__basic_leave_room(sid, namespace, room) + + def _emit(self, event, data, namespace, room=None, skip_sid=None, + callback=None, **kwargs): + ret = self.sio.manager.__emit(event, data, namespace, room=room, + skip_sid=skip_sid, callback=callback, + **kwargs) + if namespace != self.admin_namespace: + event_data = [event] + list(data) if isinstance(data, tuple) \ + else [data] + if not isinstance(skip_sid, list): # pragma: no branch + skip_sid = [skip_sid] + for sid, _ in self.sio.manager.get_participants(namespace, room): + if sid not in skip_sid: + self.sio.emit('event_sent', ( + namespace, + sid, + event_data, + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return ret + + def _handle_event_internal(self, server, sid, eio_sid, data, namespace, + id): + ret = self.sio.__handle_event_internal(server, sid, eio_sid, data, + namespace, id) + self.sio.emit('event_received', ( + namespace, + sid, + data, + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return ret + + def _handle_eio_connect(self, eio_sid, environ): + if self.stop_stats_event is None: + self.stop_stats_event = self.sio.eio.create_event() + self.stats_task = self.sio.start_background_task( + self._emit_server_stats) + + self.event_buffer.push('rawConnection') + return self.sio._handle_eio_connect(eio_sid, environ) + + def _handle_eio_disconnect(self, eio_sid): + self.event_buffer.push('rawDisconnection') + return self.sio._handle_eio_disconnect(eio_sid) + + def _eio_http_response(self, packets=None, headers=None, jsonp_index=None): + ret = self.sio.eio.__ok(packets=packets, headers=headers, + jsonp_index=jsonp_index) + self.event_buffer.push('packetsOut') + self.event_buffer.push('bytesOut', len(ret['response'])) + return ret + + def _eio_handle_post_request(socket, self, environ): + ret = socket.__handle_post_request(environ) + self.event_buffer.push('packetsIn') + self.event_buffer.push( + 'bytesIn', int(environ.get('CONTENT_LENGTH', 0))) + return ret + + def _eio_websocket_handler(socket, self, ws): + def _send(ws, data, *args, **kwargs): + self.event_buffer.push('packetsOut') + self.event_buffer.push('bytesOut', len(data)) + return ws.__send(data, *args, **kwargs) + + def _wait(ws): + ret = ws.__wait() + self.event_buffer.push('packetsIn') + self.event_buffer.push('bytesIn', len(ret or '')) + return ret + + ws.__send = ws.send + ws.send = functools.partial(_send, ws) + ws.__wait = ws.wait + ws.wait = functools.partial(_wait, ws) + return socket.__websocket_handler(ws) + + def _eio_send_ping(socket, self): # pragma: no cover + eio_sid = socket.sid + t = time.time() + for namespace in self.sio.manager.get_namespaces(): + sid = self.sio.manager.sid_from_eio_sid(eio_sid, namespace) + if sid: + serialized_socket = self.serialize_socket(sid, namespace, + eio_sid) + self.sio.emit('socket_connected', ( + serialized_socket, + datetime.utcfromtimestamp(t).isoformat() + 'Z', + ), namespace=self.admin_namespace) + return socket.__send_ping() + + def _emit_server_stats(self): + start_time = time.time() + namespaces = list(self.sio.handlers.keys()) + namespaces.sort() + while not self.stop_stats_event.is_set(): + self.sio.sleep(self.server_stats_interval) + self.sio.emit('server_stats', { + 'serverId': self.server_id, + 'hostname': HOSTNAME, + 'pid': PID, + 'uptime': time.time() - start_time, + 'clientsCount': len(self.sio.eio.sockets), + 'pollingClientsCount': len( + [s for s in self.sio.eio.sockets.values() + if not s.upgraded]), + 'aggregatedEvents': self.event_buffer.get_and_clear(), + 'namespaces': [{ + 'name': nsp, + 'socketsCount': len(self.sio.manager.rooms.get( + nsp, {None: []}).get(None, [])) + } for nsp in namespaces], + }, namespace=self.admin_namespace) + + def serialize_socket(self, sid, namespace, eio_sid=None): + if eio_sid is None: # pragma: no cover + eio_sid = self.sio.manager.eio_sid_from_sid(sid) + socket = self.sio.eio._get_socket(eio_sid) + environ = self.sio.environ.get(eio_sid, {}) + tm = self.sio.manager._timestamps[sid] if sid in \ + self.sio.manager._timestamps else 0 + return { + 'id': sid, + 'clientId': eio_sid, + 'transport': 'websocket' if socket.upgraded else 'polling', + 'nsp': namespace, + 'data': {}, + 'handshake': { + 'address': environ.get('REMOTE_ADDR', ''), + 'headers': {k[5:].lower(): v for k, v in environ.items() + if k.startswith('HTTP_')}, + 'query': {k: v[0] if len(v) == 1 else v for k, v in parse_qs( + environ.get('QUERY_STRING', '')).items()}, + 'secure': environ.get('wsgi.url_scheme', '') == 'https', + 'url': environ.get('PATH_INFO', ''), + 'issued': tm * 1000, + 'time': datetime.utcfromtimestamp(tm).isoformat() + 'Z' + if tm else '', + }, + 'rooms': self.sio.manager.get_rooms(sid, namespace), + } diff --git a/src/socketio/async_admin.py b/src/socketio/async_admin.py new file mode 100644 index 0000000..162c566 --- /dev/null +++ b/src/socketio/async_admin.py @@ -0,0 +1,398 @@ +import asyncio +from datetime import datetime +import functools +import os +import socket +import time +from urllib.parse import parse_qs +from .admin import EventBuffer +from .exceptions import ConnectionRefusedError + +HOSTNAME = socket.gethostname() +PID = os.getpid() + + +class InstrumentedAsyncServer: + def __init__(self, sio, auth=None, namespace='/admin', read_only=False, + server_id=None, mode='development', server_stats_interval=2): + """Instrument the Socket.IO server for monitoring with the `Socket.IO + Admin UI <https://socket.io/docs/v4/admin-ui/>`_. + """ + if auth is None: + raise ValueError('auth must be specified') + self.sio = sio + self.auth = auth + self.admin_namespace = namespace + self.read_only = read_only + self.server_id = server_id or ( + self.sio.manager.host_id if hasattr(self.sio.manager, 'host_id') + else HOSTNAME + ) + self.mode = mode + self.server_stats_interval = server_stats_interval + self.admin_queue = [] + self.event_buffer = EventBuffer() + + # task that emits "server_stats" every 2 seconds + self.stop_stats_event = None + self.stats_task = None + + # monkey-patch the server to report metrics to the admin UI + self.instrument() + + def instrument(self): + self.sio.on('connect', self.admin_connect, + namespace=self.admin_namespace) + + if self.mode == 'development': + if not self.read_only: # pragma: no branch + self.sio.on('emit', self.admin_emit, + namespace=self.admin_namespace) + self.sio.on('join', self.admin_enter_room, + namespace=self.admin_namespace) + self.sio.on('leave', self.admin_leave_room, + namespace=self.admin_namespace) + self.sio.on('_disconnect', self.admin_disconnect, + namespace=self.admin_namespace) + + # track socket connection times + self.sio.manager._timestamps = {} + + # report socket.io connections + self.sio.manager.__connect = self.sio.manager.connect + self.sio.manager.connect = self._connect + + # report socket.io disconnection + self.sio.manager.__disconnect = self.sio.manager.disconnect + self.sio.manager.disconnect = self._disconnect + + # report join rooms + self.sio.manager.__basic_enter_room = \ + self.sio.manager.basic_enter_room + self.sio.manager.basic_enter_room = self._basic_enter_room + + # report leave rooms + self.sio.manager.__basic_leave_room = \ + self.sio.manager.basic_leave_room + self.sio.manager.basic_leave_room = self._basic_leave_room + + # report emit events + self.sio.manager.__emit = self.sio.manager.emit + self.sio.manager.emit = self._emit + + # report receive events + self.sio.__handle_event_internal = self.sio._handle_event_internal + self.sio._handle_event_internal = self._handle_event_internal + + # report engine.io connections + self.sio.eio.on('connect', self._handle_eio_connect) + self.sio.eio.on('disconnect', self._handle_eio_disconnect) + + # report polling packets + from engineio.async_socket import AsyncSocket + self.sio.eio.__ok = self.sio.eio._ok + self.sio.eio._ok = self._eio_http_response + AsyncSocket.__handle_post_request = AsyncSocket.handle_post_request + AsyncSocket.handle_post_request = functools.partialmethod( + self.__class__._eio_handle_post_request, self) + + # report websocket packets + AsyncSocket.__websocket_handler = AsyncSocket._websocket_handler + AsyncSocket._websocket_handler = functools.partialmethod( + self.__class__._eio_websocket_handler, self) + + # report connected sockets with each ping + if self.mode == 'development': + AsyncSocket.__send_ping = AsyncSocket._send_ping + AsyncSocket._send_ping = functools.partialmethod( + self.__class__._eio_send_ping, self) + + def uninstrument(self): # pragma: no cover + if self.mode == 'development': + self.sio.manager.connect = self.sio.manager.__connect + self.sio.manager.disconnect = self.sio.manager.__disconnect + self.sio.manager.basic_enter_room = \ + self.sio.manager.__basic_enter_room + self.sio.manager.basic_leave_room = \ + self.sio.manager.__basic_leave_room + self.sio.manager.emit = self.sio.manager.__emit + self.sio._handle_event_internal = self.sio.__handle_event_internal + self.sio.eio._ok = self.sio.eio.__ok + + from engineio.async_socket import AsyncSocket + AsyncSocket.handle_post_request = AsyncSocket.__handle_post_request + AsyncSocket._websocket_handler = AsyncSocket.__websocket_handler + if self.mode == 'development': + AsyncSocket._send_ping = AsyncSocket.__send_ping + + async def admin_connect(self, sid, environ, client_auth): + authenticated = True + if self.auth: + authenticated = False + if isinstance(self.auth, dict): + authenticated = client_auth == self.auth + elif isinstance(self.auth, list): + authenticated = client_auth in self.auth + else: + if asyncio.iscoroutinefunction(self.auth): + authenticated = await self.auth(client_auth) + else: + authenticated = self.auth(client_auth) + if not authenticated: + raise ConnectionRefusedError('authentication failed') + + async def config(sid): + await self.sio.sleep(0.1) + + # supported features + features = ['AGGREGATED_EVENTS'] + if not self.read_only: + features += ['EMIT', 'JOIN', 'LEAVE', 'DISCONNECT', 'MJOIN', + 'MLEAVE', 'MDISCONNECT'] + if self.mode == 'development': + features.append('ALL_EVENTS') + await self.sio.emit('config', {'supportedFeatures': features}, + to=sid, namespace=self.admin_namespace) + + # send current sockets + if self.mode == 'development': + all_sockets = [] + for nsp in self.sio.manager.get_namespaces(): + for sid, eio_sid in self.sio.manager.get_participants( + nsp, None): + all_sockets.append( + self.serialize_socket(sid, nsp, eio_sid)) + await self.sio.emit('all_sockets', all_sockets, to=sid, + namespace=self.admin_namespace) + + self.sio.start_background_task(config, sid) + self.stop_stats_event = self.sio.eio.create_event() + self.stats_task = self.sio.start_background_task( + self._emit_server_stats) + + async def admin_emit(self, _, namespace, room_filter, event, *data): + await self.sio.emit(event, data, to=room_filter, namespace=namespace) + + async def admin_enter_room(self, _, namespace, room, room_filter=None): + for sid, _ in self.sio.manager.get_participants( + namespace, room_filter): + await self.sio.enter_room(sid, room, namespace=namespace) + + async def admin_leave_room(self, _, namespace, room, room_filter=None): + for sid, _ in self.sio.manager.get_participants( + namespace, room_filter): + await self.sio.leave_room(sid, room, namespace=namespace) + + async def admin_disconnect(self, _, namespace, close, room_filter=None): + for sid, _ in self.sio.manager.get_participants( + namespace, room_filter): + await self.sio.disconnect(sid, namespace=namespace) + + async def shutdown(self): + if self.stats_task: # pragma: no branch + self.stop_stats_event.set() + await asyncio.gather(self.stats_task) + + async def _connect(self, eio_sid, namespace): + sid = await self.sio.manager.__connect(eio_sid, namespace) + t = time.time() + self.sio.manager._timestamps[sid] = t + serialized_socket = self.serialize_socket(sid, namespace, eio_sid) + await self.sio.emit('socket_connected', ( + serialized_socket, + datetime.utcfromtimestamp(t).isoformat() + 'Z', + ), namespace=self.admin_namespace) + return sid + + async def _disconnect(self, sid, namespace, **kwargs): + del self.sio.manager._timestamps[sid] + await self.sio.emit('socket_disconnected', ( + namespace, + sid, + 'N/A', + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return await self.sio.manager.__disconnect(sid, namespace, **kwargs) + + async def _check_for_upgrade(self, eio_sid, sid, + namespace): # pragma: no cover + for _ in range(5): + await self.sio.sleep(5) + try: + if self.sio.eio._get_socket(eio_sid).upgraded: + await self.sio.emit('socket_updated', { + 'id': sid, + 'nsp': namespace, + 'transport': 'websocket', + }, namespace=self.admin_namespace) + break + except KeyError: + pass + + def _basic_enter_room(self, sid, namespace, room, eio_sid=None): + ret = self.sio.manager.__basic_enter_room(sid, namespace, room, + eio_sid) + if room: + self.admin_queue.append(('room_joined', ( + namespace, + room, + sid, + datetime.utcnow().isoformat() + 'Z', + ))) + return ret + + def _basic_leave_room(self, sid, namespace, room): + if room: + self.admin_queue.append(('room_left', ( + namespace, + room, + sid, + datetime.utcnow().isoformat() + 'Z', + ))) + return self.sio.manager.__basic_leave_room(sid, namespace, room) + + async def _emit(self, event, data, namespace, room=None, skip_sid=None, + callback=None, **kwargs): + ret = await self.sio.manager.__emit( + event, data, namespace, room=room, skip_sid=skip_sid, + callback=callback, **kwargs) + if namespace != self.admin_namespace: + event_data = [event] + list(data) if isinstance(data, tuple) \ + else [data] + if not isinstance(skip_sid, list): # pragma: no branch + skip_sid = [skip_sid] + for sid, _ in self.sio.manager.get_participants(namespace, room): + if sid not in skip_sid: + await self.sio.emit('event_sent', ( + namespace, + sid, + event_data, + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return ret + + async def _handle_event_internal(self, server, sid, eio_sid, data, + namespace, id): + ret = await self.sio.__handle_event_internal(server, sid, eio_sid, + data, namespace, id) + await self.sio.emit('event_received', ( + namespace, + sid, + data, + datetime.utcnow().isoformat() + 'Z', + ), namespace=self.admin_namespace) + return ret + + async def _handle_eio_connect(self, eio_sid, environ): + if self.stop_stats_event is None: + self.stop_stats_event = self.sio.eio.create_event() + self.stats_task = self.sio.start_background_task( + self._emit_server_stats) + + self.event_buffer.push('rawConnection') + return await self.sio._handle_eio_connect(eio_sid, environ) + + async def _handle_eio_disconnect(self, eio_sid): + self.event_buffer.push('rawDisconnection') + return await self.sio._handle_eio_disconnect(eio_sid) + + def _eio_http_response(self, packets=None, headers=None, jsonp_index=None): + ret = self.sio.eio.__ok(packets=packets, headers=headers, + jsonp_index=jsonp_index) + self.event_buffer.push('packetsOut') + self.event_buffer.push('bytesOut', len(ret['response'])) + return ret + + async def _eio_handle_post_request(socket, self, environ): + ret = await socket.__handle_post_request(environ) + self.event_buffer.push('packetsIn') + self.event_buffer.push( + 'bytesIn', int(environ.get('CONTENT_LENGTH', 0))) + return ret + + async def _eio_websocket_handler(socket, self, ws): + async def _send(ws, data): + self.event_buffer.push('packetsOut') + self.event_buffer.push('bytesOut', len(data)) + return await ws.__send(data) + + async def _wait(ws): + ret = await ws.__wait() + self.event_buffer.push('packetsIn') + self.event_buffer.push('bytesIn', len(ret or '')) + return ret + + ws.__send = ws.send + ws.send = functools.partial(_send, ws) + ws.__wait = ws.wait + ws.wait = functools.partial(_wait, ws) + return await socket.__websocket_handler(ws) + + async def _eio_send_ping(socket, self): # pragma: no cover + eio_sid = socket.sid + t = time.time() + for namespace in self.sio.manager.get_namespaces(): + sid = self.sio.manager.sid_from_eio_sid(eio_sid, namespace) + if sid: + serialized_socket = self.serialize_socket(sid, namespace, + eio_sid) + await self.sio.emit('socket_connected', ( + serialized_socket, + datetime.utcfromtimestamp(t).isoformat() + 'Z', + ), namespace=self.admin_namespace) + return await socket.__send_ping() + + async def _emit_server_stats(self): + start_time = time.time() + namespaces = list(self.sio.handlers.keys()) + namespaces.sort() + while not self.stop_stats_event.is_set(): + await self.sio.sleep(self.server_stats_interval) + await self.sio.emit('server_stats', { + 'serverId': self.server_id, + 'hostname': HOSTNAME, + 'pid': PID, + 'uptime': time.time() - start_time, + 'clientsCount': len(self.sio.eio.sockets), + 'pollingClientsCount': len( + [s for s in self.sio.eio.sockets.values() + if not s.upgraded]), + 'aggregatedEvents': self.event_buffer.get_and_clear(), + 'namespaces': [{ + 'name': nsp, + 'socketsCount': len(self.sio.manager.rooms.get( + nsp, {None: []}).get(None, [])) + } for nsp in namespaces], + }, namespace=self.admin_namespace) + while self.admin_queue: + event, args = self.admin_queue.pop(0) + await self.sio.emit(event, args, + namespace=self.admin_namespace) + + def serialize_socket(self, sid, namespace, eio_sid=None): + if eio_sid is None: # pragma: no cover + eio_sid = self.sio.manager.eio_sid_from_sid(sid) + socket = self.sio.eio._get_socket(eio_sid) + environ = self.sio.environ.get(eio_sid, {}) + tm = self.sio.manager._timestamps[sid] if sid in \ + self.sio.manager._timestamps else 0 + return { + 'id': sid, + 'clientId': eio_sid, + 'transport': 'websocket' if socket.upgraded else 'polling', + 'nsp': namespace, + 'data': {}, + 'handshake': { + 'address': environ.get('REMOTE_ADDR', ''), + 'headers': {k[5:].lower(): v for k, v in environ.items() + if k.startswith('HTTP_')}, + 'query': {k: v[0] if len(v) == 1 else v for k, v in parse_qs( + environ.get('QUERY_STRING', '')).items()}, + 'secure': environ.get('wsgi.url_scheme', '') == 'https', + 'url': environ.get('PATH_INFO', ''), + 'issued': tm * 1000, + 'time': datetime.utcfromtimestamp(tm).isoformat() + 'Z' + if tm else '', + }, + 'rooms': self.sio.manager.get_rooms(sid, namespace), + } diff --git a/src/socketio/async_manager.py b/src/socketio/async_manager.py index 6646376..dcf79cf 100644 --- a/src/socketio/async_manager.py +++ b/src/socketio/async_manager.py @@ -62,6 +62,13 @@ class AsyncManager(BaseManager): return await asyncio.wait(tasks) + async def connect(self, eio_sid, namespace): + """Register a client connection to a namespace. + + Note: this method is a coroutine. + """ + return super().connect(eio_sid, namespace) + async def disconnect(self, sid, namespace, **kwargs): """Disconnect a client. diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index f3bb8f8..99af067 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -467,6 +467,45 @@ class AsyncServer(base_server.BaseServer): """ return await self.eio.sleep(seconds) + def instrument(self, auth=None, mode='development', read_only=False, + server_id=None, namespace='/admin', + server_stats_interval=2): + """Instrument the Socket.IO server for monitoring with the `Socket.IO + Admin UI <https://socket.io/docs/v4/admin-ui/>`_. + + :param auth: Authentication credentials for Admin UI access. Set to a + dictionary with the expected login (usually ``username`` + and ``password``) or a list of dictionaries if more than + one set of credentials need to be available. For more + complex authentication methods, set to a callable that + receives the authentication dictionary as an argument and + returns ``True`` if the user is allowed or ``False`` + otherwise. To disable authentication, set this argument to + ``False`` (not recommended, never do this on a production + server). + :param mode: The reporting mode. The default is ``'development'``, + which is best used while debugging, as it may have a + significant performance effect. Set to ``'production'`` to + reduce the amount of information that is reported to the + admin UI. + :param read_only: If set to ``True``, the admin interface will be + read-only, with no option to modify room assignments + or disconnect clients. The default is ``False``. + :param server_id: The server name to use for this server. If this + argument is omitted, the server generates its own + name. + :param namespace: The Socket.IO namespace to use for the admin + interface. The default is ``/admin``. + :param server_stats_interval: The interval in seconds at which the + server emits a summary of it stats to all + connected admins. + """ + from .async_admin import InstrumentedAsyncServer + return InstrumentedAsyncServer( + self, auth=auth, mode=mode, read_only=read_only, + server_id=server_id, namespace=namespace, + server_stats_interval=server_stats_interval) + async def _send_packet(self, eio_sid, pkt): """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() @@ -486,7 +525,7 @@ class AsyncServer(base_server.BaseServer): sid = None if namespace in self.handlers or namespace in self.namespace_handlers \ or self.namespaces == '*' or namespace in self.namespaces: - sid = self.manager.connect(eio_sid, namespace) + sid = await self.manager.connect(eio_sid, namespace) if sid is None: await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', diff --git a/src/socketio/async_simple_client.py b/src/socketio/async_simple_client.py index b21418a..db90784 100644 --- a/src/socketio/async_simple_client.py +++ b/src/socketio/async_simple_client.py @@ -23,7 +23,8 @@ class AsyncSimpleClient: self.input_buffer = [] async def connect(self, url, headers={}, auth=None, transports=None, - namespace='/', socketio_path='socket.io'): + namespace='/', socketio_path='socket.io', + wait_timeout=5): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -49,6 +50,8 @@ class AsyncSimpleClient: :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait_timeout: How long the client should wait for the + connection. The default is 5 seconds. Note: this method is a coroutine. """ @@ -80,7 +83,8 @@ class AsyncSimpleClient: await self.client.connect( url, headers=headers, auth=auth, transports=transports, - namespaces=[namespace], socketio_path=socketio_path) + namespaces=[namespace], socketio_path=socketio_path, + wait_timeout=wait_timeout) @property def sid(self): @@ -89,7 +93,7 @@ class AsyncSimpleClient: The session ID is not guaranteed to remain constant throughout the life of the connection, as reconnections can cause it to change. """ - return self.client.sid if self.client else None + return self.client.get_sid(self.namespace) if self.client else None @property def transport(self): diff --git a/src/socketio/base_manager.py b/src/socketio/base_manager.py index 6e145a1..ca4b0b9 100644 --- a/src/socketio/base_manager.py +++ b/src/socketio/base_manager.py @@ -30,7 +30,7 @@ class BaseManager: def get_participants(self, namespace, room): """Return an iterable with the active participants in a room.""" - ns = self.rooms[namespace] + ns = self.rooms.get(namespace, {}) if hasattr(room, '__len__') and not isinstance(room, str): participants = ns[room[0]]._fwdm.copy() if room[0] in ns else {} for r in room[1:]: @@ -126,7 +126,7 @@ class BaseManager: try: for sid, _ in self.get_participants(namespace, room): self.basic_leave_room(sid, namespace, room) - except KeyError: + except KeyError: # pragma: no cover pass def get_rooms(self, sid, namespace): diff --git a/src/socketio/server.py b/src/socketio/server.py index 275e530..2081337 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -454,6 +454,45 @@ class Server(base_server.BaseServer): """ return self.eio.sleep(seconds) + def instrument(self, auth=None, mode='development', read_only=False, + server_id=None, namespace='/admin', + server_stats_interval=2): + """Instrument the Socket.IO server for monitoring with the `Socket.IO + Admin UI <https://socket.io/docs/v4/admin-ui/>`_. + + :param auth: Authentication credentials for Admin UI access. Set to a + dictionary with the expected login (usually ``username`` + and ``password``) or a list of dictionaries if more than + one set of credentials need to be available. For more + complex authentication methods, set to a callable that + receives the authentication dictionary as an argument and + returns ``True`` if the user is allowed or ``False`` + otherwise. To disable authentication, set this argument to + ``False`` (not recommended, never do this on a production + server). + :param mode: The reporting mode. The default is ``'development'``, + which is best used while debugging, as it may have a + significant performance effect. Set to ``'production'`` to + reduce the amount of information that is reported to the + admin UI. + :param read_only: If set to ``True``, the admin interface will be + read-only, with no option to modify room assignments + or disconnect clients. The default is ``False``. + :param server_id: The server name to use for this server. If this + argument is omitted, the server generates its own + name. + :param namespace: The Socket.IO namespace to use for the admin + interface. The default is ``/admin``. + :param server_stats_interval: The interval in seconds at which the + server emits a summary of it stats to all + connected admins. + """ + from .admin import InstrumentedServer + return InstrumentedServer( + self, auth=auth, mode=mode, read_only=read_only, + server_id=server_id, namespace=namespace, + server_stats_interval=server_stats_interval) + def _send_packet(self, eio_sid, pkt): """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() diff --git a/src/socketio/simple_client.py b/src/socketio/simple_client.py index 4a88380..ce3a1c5 100644 --- a/src/socketio/simple_client.py +++ b/src/socketio/simple_client.py @@ -23,7 +23,7 @@ class SimpleClient: self.input_buffer = [] def connect(self, url, headers={}, auth=None, transports=None, - namespace='/', socketio_path='socket.io'): + namespace='/', socketio_path='socket.io', wait_timeout=5): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -49,6 +49,9 @@ class SimpleClient: :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait_timeout: How long the client should wait for the + connection to be established. The default is 5 + seconds. """ if self.connected: raise RuntimeError('Already connected') @@ -78,7 +81,8 @@ class SimpleClient: self.client.connect(url, headers=headers, auth=auth, transports=transports, namespaces=[namespace], - socketio_path=socketio_path) + socketio_path=socketio_path, + wait_timeout=wait_timeout) @property def sid(self): @@ -87,7 +91,7 @@ class SimpleClient: The session ID is not guaranteed to remain constant throughout the life of the connection, as reconnections can cause it to change. """ - return self.client.sid if self.client else None + return self.client.get_sid(self.namespace) if self.client else None @property def transport(self): diff --git a/tests/async/test_admin.py b/tests/async/test_admin.py new file mode 100644 index 0000000..95b8fef --- /dev/null +++ b/tests/async/test_admin.py @@ -0,0 +1,311 @@ +from functools import wraps +import threading +import time +from unittest import mock +import unittest +import pytest +try: + from engineio.async_socket import AsyncSocket as EngineIOSocket +except ImportError: + from engineio.asyncio_socket import AsyncSocket as EngineIOSocket +import socketio +from socketio.exceptions import ConnectionError +from tests.asyncio_web_server import SocketIOWebServer +from .helpers import AsyncMock + + +def with_instrumented_server(auth=False, **ikwargs): + """This decorator can be applied to test functions or methods so that they + run with a Socket.IO server that has been instrumented for the official + Admin UI project. The arguments passed to the decorator are passed directly + to the ``instrument()`` method of the server. + """ + def decorator(f): + @wraps(f) + def wrapped(self, *args, **kwargs): + sio = socketio.AsyncServer(async_mode='asgi') + + @sio.event + async def enter_room(sid, data): + await sio.enter_room(sid, data) + + @sio.event + async def emit(sid, event): + await sio.emit(event, skip_sid=sid) + + @sio.event(namespace='/foo') + def connect(sid, environ, auth): + pass + + async def shutdown(): + await instrumented_server.shutdown() + await sio.shutdown() + + if 'server_stats_interval' not in ikwargs: + ikwargs['server_stats_interval'] = 0.25 + + instrumented_server = sio.instrument(auth=auth, **ikwargs) + server = SocketIOWebServer(sio, on_shutdown=shutdown) + server.start() + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.DEBUG) + # logging.getLogger('socketio.client').setLevel(logging.DEBUG) + + original_schedule_ping = EngineIOSocket.schedule_ping + EngineIOSocket.schedule_ping = mock.MagicMock() + + try: + ret = f(self, instrumented_server, *args, **kwargs) + finally: + server.stop() + instrumented_server.uninstrument() + + EngineIOSocket.schedule_ping = original_schedule_ping + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.NOTSET) + # logging.getLogger('socketio.client').setLevel(logging.NOTSET) + + return ret + return wrapped + return decorator + + +def _custom_auth(auth): + return auth == {'foo': 'bar'} + + +async def _async_custom_auth(auth): + return auth == {'foo': 'bar'} + + +class TestAsyncAdmin(unittest.TestCase): + def setUp(self): + print('threads at start:', threading.enumerate()) + self.thread_count = threading.active_count() + + def tearDown(self): + print('threads at end:', threading.enumerate()) + assert self.thread_count == threading.active_count() + + def test_missing_auth(self): + sio = socketio.AsyncServer(async_mode='asgi') + with pytest.raises(ValueError): + sio.instrument() + + @with_instrumented_server(auth=False) + def test_admin_connect_with_no_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + + @with_instrumented_server(auth={'foo': 'bar'}) + def test_admin_connect_with_dict_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin', + auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin') + + @with_instrumented_server(auth=[{'foo': 'bar'}, + {'u': 'admin', 'p': 'secret'}]) + def test_admin_connect_with_list_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'u': 'admin', 'p': 'secret'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server(auth=_custom_auth) + def test_admin_connect_with_function_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server(auth=_async_custom_auth) + def test_admin_connect_with_async_function_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server() + def test_admin_connect_only_admin(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures'] + assert 'EMIT' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 1 + assert events['all_sockets'][0]['id'] == sid + assert events['all_sockets'][0]['rooms'] == [sid] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_connect_with_others(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as client3, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client1.emit('enter_room', 'room') + sid1 = client1.sid + + saved_check_for_upgrade = isvr._check_for_upgrade + isvr._check_for_upgrade = AsyncMock() + client2.connect('http://localhost:8900', namespace='/foo', + transports=['polling']) + sid2 = client2.sid + isvr._check_for_upgrade = saved_check_for_upgrade + + client3.connect('http://localhost:8900', namespace='/admin') + sid3 = client3.sid + + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures'] + assert 'EMIT' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 4 + assert events['server_stats']['clientsCount'] == 4 + assert events['server_stats']['pollingClientsCount'] == 1 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 2} in \ + events['server_stats']['namespaces'] + + for socket in events['all_sockets']: + if socket['id'] == sid: + assert socket['rooms'] == [sid] + elif socket['id'] == sid1: + assert socket['rooms'] == [sid1, 'room'] + elif socket['id'] == sid2: + assert socket['rooms'] == [sid2] + elif socket['id'] == sid3: + assert socket['rooms'] == [sid3] + + @with_instrumented_server(mode='production', read_only=True) + def test_admin_connect_production(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + expected = ['config', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' not in events['config']['supportedFeatures'] + assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures'] + assert 'EMIT' not in events['config']['supportedFeatures'] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_features(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client2.connect('http://localhost:8900') + admin_client.connect('http://localhost:8900', namespace='/admin') + + # emit from admin + admin_client.emit( + 'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra')) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}, 'extra'] + + # emit from regular client + client1.emit('emit', 'foo') + data = client2.receive(timeout=5) + assert data == ['foo'] + + # join and leave + admin_client.emit('join', ('/', 'room', client1.sid)) + admin_client.emit( + 'emit', ('/', 'room', 'foo', {'bar': 'baz'})) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}] + admin_client.emit('leave', ('/', 'room')) + + # disconnect + admin_client.emit('_disconnect', ('/', False, client1.sid)) + for _ in range(10): + if not client1.connected: + break + time.sleep(0.2) + assert not client1.connected diff --git a/tests/async/test_manager.py b/tests/async/test_manager.py index 7cfb46c..0b2fc93 100644 --- a/tests/async/test_manager.py +++ b/tests/async/test_manager.py @@ -27,7 +27,7 @@ class TestAsyncManager(unittest.TestCase): self.bm.initialize() def test_connect(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) assert None in self.bm.rooms['/foo'] assert sid in self.bm.rooms['/foo'] assert sid in self.bm.rooms['/foo'][None] @@ -37,8 +37,8 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.sid_from_eio_sid('123', '/foo') == sid def test_pre_disconnect(self): - sid1 = self.bm.connect('123', '/foo') - sid2 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) + sid2 = _run(self.bm.connect('456', '/foo')) assert self.bm.is_connected(sid1, '/foo') assert self.bm.pre_disconnect(sid1, '/foo') == '123' assert self.bm.pending_disconnect == {'/foo': [sid1]} @@ -52,8 +52,8 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.pending_disconnect == {} def test_disconnect(self): - sid1 = self.bm.connect('123', '/foo') - sid2 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) _run(self.bm.enter_room(sid2, '/foo', 'baz')) _run(self.bm.disconnect(sid1, '/foo')) @@ -62,10 +62,10 @@ class TestAsyncManager(unittest.TestCase): assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'} def test_disconnect_default_namespace(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') - sid3 = self.bm.connect('456', '/') - sid4 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) + sid3 = _run(self.bm.connect('456', '/')) + sid4 = _run(self.bm.connect('456', '/foo')) assert self.bm.is_connected(sid1, '/') assert self.bm.is_connected(sid2, '/foo') assert not self.bm.is_connected(sid2, '/') @@ -81,10 +81,10 @@ class TestAsyncManager(unittest.TestCase): assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'} def test_disconnect_twice(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') - sid3 = self.bm.connect('456', '/') - sid4 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) + sid3 = _run(self.bm.connect('456', '/')) + sid4 = _run(self.bm.connect('456', '/foo')) _run(self.bm.disconnect(sid1, '/')) _run(self.bm.disconnect(sid2, '/foo')) _run(self.bm.disconnect(sid1, '/')) @@ -95,8 +95,8 @@ class TestAsyncManager(unittest.TestCase): assert dict(self.bm.rooms['/foo'][sid4]) == {sid4: '456'} def test_disconnect_all(self): - sid1 = self.bm.connect('123', '/foo') - sid2 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) _run(self.bm.enter_room(sid2, '/foo', 'baz')) _run(self.bm.disconnect(sid1, '/foo')) @@ -104,9 +104,9 @@ class TestAsyncManager(unittest.TestCase): assert self.bm.rooms == {} def test_disconnect_with_callbacks(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') - sid3 = self.bm.connect('456', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) + sid3 = _run(self.bm.connect('456', '/foo')) self.bm._generate_ack_id(sid1, 'f') self.bm._generate_ack_id(sid2, 'g') self.bm._generate_ack_id(sid3, 'h') @@ -117,8 +117,8 @@ class TestAsyncManager(unittest.TestCase): assert sid3 in self.bm.callbacks def test_trigger_sync_callback(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) cb = mock.MagicMock() id1 = self.bm._generate_ack_id(sid1, cb) id2 = self.bm._generate_ack_id(sid2, cb) @@ -129,8 +129,8 @@ class TestAsyncManager(unittest.TestCase): cb.assert_any_call('bar', 'baz') def test_trigger_async_callback(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('123', '/foo')) cb = AsyncMock() id1 = self.bm._generate_ack_id(sid1, cb) id2 = self.bm._generate_ack_id(sid2, cb) @@ -141,7 +141,7 @@ class TestAsyncManager(unittest.TestCase): cb.mock.assert_any_call('bar', 'baz') def test_invalid_callback(self): - sid = self.bm.connect('123', '/') + sid = _run(self.bm.connect('123', '/')) cb = mock.MagicMock() id = self.bm._generate_ack_id(sid, cb) @@ -152,17 +152,17 @@ class TestAsyncManager(unittest.TestCase): def test_get_namespaces(self): assert list(self.bm.get_namespaces()) == [] - self.bm.connect('123', '/') - self.bm.connect('123', '/foo') + _run(self.bm.connect('123', '/')) + _run(self.bm.connect('123', '/foo')) namespaces = list(self.bm.get_namespaces()) assert len(namespaces) == 2 assert '/' in namespaces assert '/foo' in namespaces def test_get_participants(self): - sid1 = self.bm.connect('123', '/') - sid2 = self.bm.connect('456', '/') - sid3 = self.bm.connect('789', '/') + sid1 = _run(self.bm.connect('123', '/')) + sid2 = _run(self.bm.connect('456', '/')) + sid3 = _run(self.bm.connect('789', '/')) _run(self.bm.disconnect(sid3, '/')) assert sid3 not in self.bm.rooms['/'][None] participants = list(self.bm.get_participants('/', None)) @@ -172,7 +172,7 @@ class TestAsyncManager(unittest.TestCase): assert (sid3, '789') not in participants def test_leave_invalid_room(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run(self.bm.leave_room(sid, '/foo', 'baz')) _run(self.bm.leave_room(sid, '/bar', 'baz')) @@ -181,9 +181,9 @@ class TestAsyncManager(unittest.TestCase): assert [] == rooms def test_close_room(self): - sid = self.bm.connect('123', '/foo') - self.bm.connect('456', '/foo') - self.bm.connect('789', '/foo') + sid = _run(self.bm.connect('123', '/foo')) + _run(self.bm.connect('456', '/foo')) + _run(self.bm.connect('789', '/foo')) _run(self.bm.enter_room(sid, '/foo', 'bar')) _run(self.bm.enter_room(sid, '/foo', 'bar')) _run(self.bm.close_room('bar', '/foo')) @@ -195,7 +195,7 @@ class TestAsyncManager(unittest.TestCase): self.bm.close_room('bar', '/foo') def test_rooms(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid, '/foo', 'bar')) r = self.bm.get_rooms(sid, '/foo') assert len(r) == 2 @@ -203,8 +203,8 @@ class TestAsyncManager(unittest.TestCase): assert 'bar' in r def test_emit_to_sid(self): - sid = self.bm.connect('123', '/foo') - self.bm.connect('456', '/foo') + sid = _run(self.bm.connect('123', '/foo')) + _run(self.bm.connect('456', '/foo')) _run( self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', room=sid @@ -217,11 +217,11 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_room(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - self.bm.connect('789', '/foo') + _run(self.bm.connect('789', '/foo')) _run( self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', room='bar' @@ -238,12 +238,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_rooms(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) _run(self.bm.enter_room(sid2, '/foo', 'baz')) - sid3 = self.bm.connect('789', '/foo') + sid3 = _run(self.bm.connect('789', '/foo')) _run(self.bm.enter_room(sid3, '/foo', 'baz')) _run( self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo', @@ -264,12 +264,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_all(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - self.bm.connect('789', '/foo') - self.bm.connect('abc', '/bar') + _run(self.bm.connect('789', '/foo')) + _run(self.bm.connect('abc', '/bar')) _run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')) assert self.bm.server._send_eio_packet.mock.call_count == 3 assert self.bm.server._send_eio_packet.mock.call_args_list[0][0][0] \ @@ -286,12 +286,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_all_skip_one(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - self.bm.connect('789', '/foo') - self.bm.connect('abc', '/bar') + _run(self.bm.connect('789', '/foo')) + _run(self.bm.connect('abc', '/bar')) _run( self.bm.emit( 'my event', {'foo': 'bar'}, namespace='/foo', skip_sid=sid2 @@ -308,12 +308,12 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_to_all_skip_two(self): - sid1 = self.bm.connect('123', '/foo') + sid1 = _run(self.bm.connect('123', '/foo')) _run(self.bm.enter_room(sid1, '/foo', 'bar')) - sid2 = self.bm.connect('456', '/foo') + sid2 = _run(self.bm.connect('456', '/foo')) _run(self.bm.enter_room(sid2, '/foo', 'bar')) - sid3 = self.bm.connect('789', '/foo') - self.bm.connect('abc', '/bar') + sid3 = _run(self.bm.connect('789', '/foo')) + _run(self.bm.connect('abc', '/bar')) _run( self.bm.emit( 'my event', @@ -329,7 +329,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",{"foo":"bar"}]' def test_emit_with_callback(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) self.bm._generate_ack_id = mock.MagicMock() self.bm._generate_ack_id.return_value = 11 _run( @@ -353,7 +353,7 @@ class TestAsyncManager(unittest.TestCase): _run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')) def test_emit_with_tuple(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run( self.bm.emit( 'my event', ('foo', 'bar'), namespace='/foo', room=sid @@ -366,7 +366,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event","foo","bar"]' def test_emit_with_list(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run( self.bm.emit( 'my event', ['foo', 'bar'], namespace='/foo', room=sid @@ -379,7 +379,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",["foo","bar"]]' def test_emit_with_none(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run( self.bm.emit( 'my event', None, namespace='/foo', room=sid @@ -392,7 +392,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event"]' def test_emit_binary(self): - sid = self.bm.connect('123', '/') + sid = _run(self.bm.connect('123', '/')) _run( self.bm.emit( u'my event', b'my binary data', namespace='/', room=sid diff --git a/tests/async/test_pubsub_manager.py b/tests/async/test_pubsub_manager.py index 3b4d0a9..48a71aa 100644 --- a/tests/async/test_pubsub_manager.py +++ b/tests/async/test_pubsub_manager.py @@ -145,7 +145,7 @@ class TestAsyncPubSubManager(unittest.TestCase): _run(self.pm.emit('foo', 'bar', callback='cb')) def test_emit_with_ignore_queue(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) _run( self.pm.emit( 'foo', 'bar', room=sid, namespace='/', ignore_queue=True @@ -159,7 +159,7 @@ class TestAsyncPubSubManager(unittest.TestCase): assert pkt.encode() == '42["foo","bar"]' def test_can_disconnect(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) assert _run(self.pm.can_disconnect(sid, '/')) is True _run(self.pm.can_disconnect(sid, '/foo')) self.pm._publish.mock.assert_called_once_with( @@ -175,14 +175,14 @@ class TestAsyncPubSubManager(unittest.TestCase): ) def test_disconnect_ignore_queue(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) self.pm.pre_disconnect(sid, '/') _run(self.pm.disconnect(sid, '/', ignore_queue=True)) self.pm._publish.mock.assert_not_called() assert self.pm.is_connected(sid, '/') is False def test_enter_room(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) _run(self.pm.enter_room(sid, '/', 'foo')) _run(self.pm.enter_room('456', '/', 'foo')) assert sid in self.pm.rooms['/']['foo'] @@ -193,7 +193,7 @@ class TestAsyncPubSubManager(unittest.TestCase): ) def test_leave_room(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) _run(self.pm.leave_room(sid, '/', 'foo')) _run(self.pm.leave_room('456', '/', 'foo')) assert 'foo' not in self.pm.rooms['/'] @@ -435,7 +435,7 @@ class TestAsyncPubSubManager(unittest.TestCase): ) def test_handle_enter_room(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) with mock.patch.object( async_manager.AsyncManager, 'enter_room', new=AsyncMock() ) as super_enter_room: @@ -456,7 +456,7 @@ class TestAsyncPubSubManager(unittest.TestCase): ) def test_handle_leave_room(self): - sid = self.pm.connect('123', '/') + sid = _run(self.pm.connect('123', '/')) with mock.patch.object( async_manager.AsyncManager, 'leave_room', new=AsyncMock() ) as super_leave_room: diff --git a/tests/async/test_server.py b/tests/async/test_server.py index e324c39..2f84b5f 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -596,7 +596,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = AsyncMock() catchall_handler = AsyncMock() s.on('msg', handler) @@ -610,7 +610,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_namespace(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/foo') + sid = _run(s.manager.connect('123', '/foo')) handler = mock.MagicMock() catchall_handler = mock.MagicMock() s.on('msg', handler, namespace='/foo') @@ -624,7 +624,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_disconnected_namespace(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - s.manager.connect('123', '/foo') + _run(s.manager.connect('123', '/foo')) handler = mock.MagicMock() s.on('my message', handler, namespace='/bar') _run(s._handle_eio_message('123', '2/bar,["my message","a","b","c"]')) @@ -633,7 +633,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_binary(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock() s.on('my message', handler) _run( @@ -652,7 +652,7 @@ class TestAsyncServer(unittest.TestCase): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) s.manager.trigger_callback = AsyncMock() - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) _run( s._handle_eio_message( '123', @@ -667,7 +667,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value='foo') s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","foo"]')) @@ -679,7 +679,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_unknown_event_with_ack(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - s.manager.connect('123', '/') + _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value='foo') s.on('my message', handler) _run(s._handle_eio_message('123', '21000["another message","foo"]')) @@ -688,7 +688,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_none(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=None) s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","foo"]')) @@ -698,7 +698,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_tuple(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=(1, '2', True)) s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","a","b","c"]')) @@ -710,7 +710,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_list(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=[1, '2', True]) s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","a","b","c"]')) @@ -722,7 +722,7 @@ class TestAsyncServer(unittest.TestCase): def test_handle_event_with_ack_binary(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) - sid = s.manager.connect('123', '/') + sid = _run(s.manager.connect('123', '/')) handler = mock.MagicMock(return_value=b'foo') s.on('my message', handler) _run(s._handle_eio_message('123', '21000["my message","foo"]')) @@ -973,7 +973,7 @@ class TestAsyncServer(unittest.TestCase): def test_async_handlers(self, eio): s = async_server.AsyncServer(async_handlers=True) - s.manager.connect('123', '/') + _run(s.manager.connect('123', '/')) _run(s._handle_eio_message('123', '2["my message","a","b","c"]')) s.eio.start_background_task.assert_called_once_with( s._handle_event_internal, diff --git a/tests/async/test_simple_client.py b/tests/async/test_simple_client.py index 397ae98..f8996bd 100644 --- a/tests/async/test_simple_client.py +++ b/tests/async/test_simple_client.py @@ -24,12 +24,13 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase): mock_client.return_value.connect = AsyncMock() _run(client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s')) + namespace='n', socketio_path='s', + wait_timeout='w')) mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.mock.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -44,12 +45,12 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase): await client.connect('url', headers='h', auth='a', transports='t', namespace='n', - socketio_path='s') + socketio_path='s', wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.mock.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -67,7 +68,8 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase): def test_properties(self): client = AsyncSimpleClient() - client.client = mock.MagicMock(sid='sid', transport='websocket') + client.client = mock.MagicMock(transport='websocket') + client.client.get_sid.return_value = 'sid' client.connected_event.set() client.connected = True diff --git a/tests/asyncio_web_server.py b/tests/asyncio_web_server.py new file mode 100644 index 0000000..8b2046c --- /dev/null +++ b/tests/asyncio_web_server.py @@ -0,0 +1,57 @@ +import requests +import threading +import time +import uvicorn +import socketio + + +class SocketIOWebServer: + """A simple web server used for running Socket.IO servers in tests. + + :param sio: a Socket.IO server instance. + + Note 1: This class is not production-ready and is intended for testing. + Note 2: This class only supports the "asgi" async_mode. + """ + def __init__(self, sio, on_shutdown=None): + if sio.async_mode != 'asgi': + raise ValueError('The async_mode must be "asgi"') + + async def http_app(scope, receive, send): + await send({'type': 'http.response.start', + 'status': 200, + 'headers': [('Content-Type', 'text/plain')]}) + await send({'type': 'http.response.body', + 'body': b'OK'}) + + self.sio = sio + self.app = socketio.ASGIApp(sio, http_app, on_shutdown=on_shutdown) + self.httpd = None + self.thread = None + + def start(self, port=8900): + """Start the web server. + + :param port: the port to listen on. Defaults to 8900. + + The server is started in a background thread. + """ + self.httpd = uvicorn.Server(config=uvicorn.Config(self.app, port=port)) + self.thread = threading.Thread(target=self.httpd.run) + self.thread.start() + + # wait for the server to start + while True: + try: + r = requests.get(f'http://localhost:{port}/') + r.raise_for_status() + if r.text == 'OK': + break + except: + time.sleep(0.1) + + def stop(self): + """Stop the web server.""" + self.httpd.should_exit = True + self.thread.join() + self.thread = None diff --git a/tests/common/test_admin.py b/tests/common/test_admin.py new file mode 100644 index 0000000..c4e2ec8 --- /dev/null +++ b/tests/common/test_admin.py @@ -0,0 +1,286 @@ +from functools import wraps +import threading +import time +from unittest import mock +import unittest +import pytest +from engineio.socket import Socket as EngineIOSocket +import socketio +from socketio.exceptions import ConnectionError +from tests.web_server import SocketIOWebServer + + +def with_instrumented_server(auth=False, **ikwargs): + """This decorator can be applied to test functions or methods so that they + run with a Socket.IO server that has been instrumented for the official + Admin UI project. The arguments passed to the decorator are passed directly + to the ``instrument()`` method of the server. + """ + def decorator(f): + @wraps(f) + def wrapped(self, *args, **kwargs): + sio = socketio.Server(async_mode='threading') + + @sio.event + def enter_room(sid, data): + sio.enter_room(sid, data) + + @sio.event + def emit(sid, event): + sio.emit(event, skip_sid=sid) + + @sio.event(namespace='/foo') + def connect(sid, environ, auth): + pass + + if 'server_stats_interval' not in ikwargs: + ikwargs['server_stats_interval'] = 0.25 + + instrumented_server = sio.instrument(auth=auth, **ikwargs) + server = SocketIOWebServer(sio) + server.start() + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.DEBUG) + # logging.getLogger('socketio.client').setLevel(logging.DEBUG) + + original_schedule_ping = EngineIOSocket.schedule_ping + EngineIOSocket.schedule_ping = mock.MagicMock() + + try: + ret = f(self, instrumented_server, *args, **kwargs) + finally: + server.stop() + instrumented_server.shutdown() + instrumented_server.uninstrument() + + EngineIOSocket.schedule_ping = original_schedule_ping + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.NOTSET) + # logging.getLogger('socketio.client').setLevel(logging.NOTSET) + + return ret + return wrapped + return decorator + + +def _custom_auth(auth): + return auth == {'foo': 'bar'} + + +class TestAdmin(unittest.TestCase): + def setUp(self): + print('threads at start:', threading.enumerate()) + self.thread_count = threading.active_count() + + def tearDown(self): + print('threads at end:', threading.enumerate()) + assert self.thread_count == threading.active_count() + + def test_missing_auth(self): + sio = socketio.Server(async_mode='threading') + with pytest.raises(ValueError): + sio.instrument() + + @with_instrumented_server(auth=False) + def test_admin_connect_with_no_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + + @with_instrumented_server(auth={'foo': 'bar'}) + def test_admin_connect_with_dict_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin', + auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin') + + @with_instrumented_server(auth=[{'foo': 'bar'}, + {'u': 'admin', 'p': 'secret'}]) + def test_admin_connect_with_list_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'u': 'admin', 'p': 'secret'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server(auth=_custom_auth) + def test_admin_connect_with_function_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server() + def test_admin_connect_only_admin(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures'] + assert 'EMIT' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 1 + assert events['all_sockets'][0]['id'] == sid + assert events['all_sockets'][0]['rooms'] == [sid] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_connect_with_others(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as client3, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client1.emit('enter_room', 'room') + sid1 = client1.sid + + saved_check_for_upgrade = isvr._check_for_upgrade + isvr._check_for_upgrade = mock.MagicMock() + client2.connect('http://localhost:8900', namespace='/foo', + transports=['polling']) + sid2 = client2.sid + isvr._check_for_upgrade = saved_check_for_upgrade + + client3.connect('http://localhost:8900', namespace='/admin') + sid3 = client3.sid + + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures'] + assert 'EMIT' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 4 + assert events['server_stats']['clientsCount'] == 4 + assert events['server_stats']['pollingClientsCount'] == 1 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 2} in \ + events['server_stats']['namespaces'] + + for socket in events['all_sockets']: + if socket['id'] == sid: + assert socket['rooms'] == [sid] + elif socket['id'] == sid1: + assert socket['rooms'] == [sid1, 'room'] + elif socket['id'] == sid2: + assert socket['rooms'] == [sid2] + elif socket['id'] == sid3: + assert socket['rooms'] == [sid3] + + @with_instrumented_server(mode='production', read_only=True) + def test_admin_connect_production(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + expected = ['config', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' not in events['config']['supportedFeatures'] + assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures'] + assert 'EMIT' not in events['config']['supportedFeatures'] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_features(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client2.connect('http://localhost:8900') + admin_client.connect('http://localhost:8900', namespace='/admin') + + # emit from admin + admin_client.emit( + 'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra')) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}, 'extra'] + + # emit from regular client + client1.emit('emit', 'foo') + data = client2.receive(timeout=5) + assert data == ['foo'] + + # join and leave + admin_client.emit('join', ('/', 'room', client1.sid)) + admin_client.emit( + 'emit', ('/', 'room', 'foo', {'bar': 'baz'})) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}] + admin_client.emit('leave', ('/', 'room')) + + # disconnect + admin_client.emit('_disconnect', ('/', False, client1.sid)) + for _ in range(10): + if not client1.connected: + break + time.sleep(0.2) + assert not client1.connected diff --git a/tests/common/test_simple_client.py b/tests/common/test_simple_client.py index 2a0b7b7..4069042 100644 --- a/tests/common/test_simple_client.py +++ b/tests/common/test_simple_client.py @@ -18,12 +18,12 @@ class TestSimpleClient(unittest.TestCase): client = SimpleClient(123, a='b') with mock.patch('socketio.simple_client.Client') as mock_client: client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s') + namespace='n', socketio_path='s', wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -33,12 +33,13 @@ class TestSimpleClient(unittest.TestCase): with SimpleClient(123, a='b') as client: with mock.patch('socketio.simple_client.Client') as mock_client: client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s') + namespace='n', socketio_path='s', + wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -54,7 +55,8 @@ class TestSimpleClient(unittest.TestCase): def test_properties(self): client = SimpleClient() - client.client = mock.MagicMock(sid='sid', transport='websocket') + client.client = mock.MagicMock(transport='websocket') + client.client.get_sid.return_value = 'sid' client.connected_event.set() client.connected = True diff --git a/tests/web_server.py b/tests/web_server.py new file mode 100644 index 0000000..cb24668 --- /dev/null +++ b/tests/web_server.py @@ -0,0 +1,81 @@ +import threading +import time +from socketserver import ThreadingMixIn +from wsgiref.simple_server import make_server, WSGIServer, WSGIRequestHandler +import requests +import socketio + + +class SocketIOWebServer: + """A simple web server used for running Socket.IO servers in tests. + + :param sio: a Socket.IO server instance. + + Note 1: This class is not production-ready and is intended for testing. + Note 2: This class only supports the "threading" async_mode, with WebSocket + support provided by the simple-websocket package. + """ + def __init__(self, sio): + if sio.async_mode != 'threading': + raise ValueError('The async_mode must be "threading"') + + def http_app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'OK'] + + self.sio = sio + self.app = socketio.WSGIApp(sio, http_app) + self.httpd = None + self.thread = None + + def start(self, port=8900): + """Start the web server. + + :param port: the port to listen on. Defaults to 8900. + + The server is started in a background thread. + """ + class ThreadingWSGIServer(ThreadingMixIn, WSGIServer): + pass + + class WebSocketRequestHandler(WSGIRequestHandler): + def get_environ(self): + env = super().get_environ() + + # pass the raw socket to the WSGI app so that it can be used + # by WebSocket connections (hack copied from gunicorn) + env['gunicorn.socket'] = self.connection + return env + + self.httpd = make_server('', port, self._app_wrapper, + ThreadingWSGIServer, WebSocketRequestHandler) + self.thread = threading.Thread(target=self.httpd.serve_forever) + self.thread.start() + + # wait for the server to start + while True: + try: + r = requests.get(f'http://localhost:{port}/') + r.raise_for_status() + if r.text == 'OK': + break + except: + time.sleep(0.1) + + def stop(self): + """Stop the web server.""" + self.sio.shutdown() + self.httpd.shutdown() + self.httpd.server_close() + self.thread.join() + self.httpd = None + self.thread = None + + def _app_wrapper(self, environ, start_response): + try: + return self.app(environ, start_response) + except StopIteration: + # end the WebSocket request without sending a response + # (this is a hack that was copied from gunicorn's threaded worker) + start_response('200 OK', []) + return [] diff --git a/tox.ini b/tox.ini index 12d2891..2929457 100644 --- a/tox.ini +++ b/tox.ini @@ -14,10 +14,16 @@ python = [testenv] commands= pip install -e . - pytest -p no:logging --cov=socketio --cov-branch --cov-report=term-missing --cov-report=xml + pytest -p no:logging --timeout=60 --cov=socketio --cov-branch --cov-report=term-missing --cov-report=xml deps= + simple-websocket + uvicorn + requests + websocket-client + aiohttp msgpack pytest + pytest-timeout pytest-cov [testenv:flake8]