diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index af896fb..ddc190c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,7 +26,7 @@ jobs: exclude: # pypy3 currently fails to run on Windows - os: windows-latest - python: pypy-3.8 + python: pypy-3.9 fail-fast: false runs-on: ${{ matrix.os }} steps: diff --git a/src/socketio/admin.py b/src/socketio/admin.py index dc45c4b..c2d2a7c 100644 --- a/src/socketio/admin.py +++ b/src/socketio/admin.py @@ -63,7 +63,7 @@ class InstrumentedServer: namespace=self.admin_namespace) if self.mode == 'development': - if not self.read_only: + if not self.read_only: # pragma: no branch self.sio.on('emit', self.admin_emit, namespace=self.admin_namespace) self.sio.on('join', self.admin_enter_room, @@ -117,8 +117,22 @@ class InstrumentedServer: Socket._websocket_handler = functools.partialmethod( self.__class__._eio_websocket_handler, self) + def uninstrument(self): # pragma: no cover + if self.mode == 'development': + self.sio.manager.connect = self.sio.manager.__connect + self.sio.manager.disconnect = self.sio.manager.__disconnect + self.sio.manager.enter_room = self.sio.manager.__enter_room + self.sio.manager.leave_room = self.sio.manager.__leave_room + self.sio.manager.emit = self.sio.manager.__emit + self.sio._handle_event_internal = self.sio.__handle_event_internal + self.sio.eio._ok = self.sio.eio.__ok + + from engineio.socket import Socket + Socket.handle_post_request = Socket.__handle_post_request + Socket._websocket_handler = Socket.__websocket_handler + def admin_connect(self, sid, environ, client_auth): - if self.auth != None: + if self.auth: authenticated = False if isinstance(self.auth, dict): authenticated = client_auth == self.auth @@ -175,8 +189,9 @@ class InstrumentedServer: self.sio.disconnect(sid, namespace=namespace) def shutdown(self): - self.stop_stats_event.set() - self.stats_thread.join() + if self.stats_task: # pragma: no branch + self.stop_stats_event.set() + self.stats_task.join() def _connect(self, eio_sid, namespace): sid = self.sio.manager.__connect(eio_sid, namespace) @@ -188,22 +203,9 @@ class InstrumentedServer: datetime.utcfromtimestamp(t).isoformat() + 'Z', ), namespace=self.admin_namespace) - def check_for_upgrade(): - for _ in range(5): - self.sio.sleep(5) - try: - if self.sio.eio._get_socket(eio_sid).upgraded: - self.sio.emit('socket_updated', { - 'id': sid, - 'nsp': namespace, - 'transport': 'websocket', - }, namespace=self.admin_namespace) - break - except KeyError: - pass - - if serialized_socket['transport'] == 'polling': - self.sio.start_background_task(check_for_upgrade) + if serialized_socket['transport'] == 'polling': # pragma: no cover + self.sio.start_background_task( + self._check_for_upgrade, eio_sid, sid, namespace) return sid def _disconnect(self, sid, namespace, **kwargs): @@ -216,6 +218,20 @@ class InstrumentedServer: ), namespace=self.admin_namespace) return self.sio.manager.__disconnect(sid, namespace, **kwargs) + def _check_for_upgrade(self, eio_sid, sid, namespace): # pragma: no cover + for _ in range(5): + self.sio.sleep(5) + try: + if self.sio.eio._get_socket(eio_sid).upgraded: + self.sio.emit('socket_updated', { + 'id': sid, + 'nsp': namespace, + 'transport': 'websocket', + }, namespace=self.admin_namespace) + break + except KeyError: + pass + def _enter_room(self, sid, namespace, room, eio_sid=None): ret = self.sio.manager.__enter_room(sid, namespace, room, eio_sid) if room: @@ -245,7 +261,7 @@ class InstrumentedServer: if namespace != self.admin_namespace: event_data = [event] + list(data) if isinstance(data, tuple) \ else [data] - if not isinstance(skip_sid, list): + if not isinstance(skip_sid, list): # pragma: no branch skip_sid = [skip_sid] for sid, _ in self.sio.manager.get_participants(namespace, room): if sid not in skip_sid: @@ -328,15 +344,17 @@ class InstrumentedServer: 'namespaces': [{ 'name': nsp, 'socketsCount': len(self.sio.manager.rooms.get( - nsp, {None: []})[None]) + nsp, {None: []}).get(None, [])) } for nsp in namespaces], }, namespace=self.admin_namespace) def serialize_socket(self, sid, namespace, eio_sid=None): - if eio_sid is None: + if eio_sid is None: # pragma: no cover eio_sid = self.sio.manager.eio_sid_from_sid(sid) socket = self.sio.eio._get_socket(eio_sid) environ = self.sio.environ.get(eio_sid, {}) + tm = self.sio.manager._timestamps[sid] if sid in \ + self.sio.manager._timestamps else 0 return { 'id': sid, 'clientId': eio_sid, @@ -351,9 +369,9 @@ class InstrumentedServer: environ.get('QUERY_STRING', '')).items()}, 'secure': environ.get('wsgi.url_scheme', '') == 'https', 'url': environ.get('PATH_INFO', ''), - 'issued': self.sio.manager._timestamps[sid] * 1000, - 'time': datetime.utcfromtimestamp( - self.sio.manager._timestamps[sid]).isoformat() + 'Z', + 'issued': tm * 1000, + 'time': datetime.utcfromtimestamp(tm).isoformat() + 'Z' + if tm else '', }, 'rooms': self.sio.manager.get_rooms(sid, namespace), } diff --git a/src/socketio/async_simple_client.py b/src/socketio/async_simple_client.py index b21418a..db90784 100644 --- a/src/socketio/async_simple_client.py +++ b/src/socketio/async_simple_client.py @@ -23,7 +23,8 @@ class AsyncSimpleClient: self.input_buffer = [] async def connect(self, url, headers={}, auth=None, transports=None, - namespace='/', socketio_path='socket.io'): + namespace='/', socketio_path='socket.io', + wait_timeout=5): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -49,6 +50,8 @@ class AsyncSimpleClient: :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait_timeout: How long the client should wait for the + connection. The default is 5 seconds. Note: this method is a coroutine. """ @@ -80,7 +83,8 @@ class AsyncSimpleClient: await self.client.connect( url, headers=headers, auth=auth, transports=transports, - namespaces=[namespace], socketio_path=socketio_path) + namespaces=[namespace], socketio_path=socketio_path, + wait_timeout=wait_timeout) @property def sid(self): @@ -89,7 +93,7 @@ class AsyncSimpleClient: The session ID is not guaranteed to remain constant throughout the life of the connection, as reconnections can cause it to change. """ - return self.client.sid if self.client else None + return self.client.get_sid(self.namespace) if self.client else None @property def transport(self): diff --git a/src/socketio/asyncio_admin.py b/src/socketio/asyncio_admin.py index b491e34..3199e10 100644 --- a/src/socketio/asyncio_admin.py +++ b/src/socketio/asyncio_admin.py @@ -44,7 +44,7 @@ class InstrumentedAsyncServer: namespace=self.admin_namespace) if self.mode == 'development': - if not self.read_only: + if not self.read_only: # pragma: no branch self.sio.on('emit', self.admin_emit, namespace=self.admin_namespace) self.sio.on('join', self.admin_enter_room, @@ -89,7 +89,8 @@ class InstrumentedAsyncServer: from engineio.asyncio_socket import AsyncSocket self.sio.eio.__ok = self.sio.eio._ok self.sio.eio._ok = self._eio_http_response - AsyncSocket.__handle_post_request = functools.partialmethod( + AsyncSocket.__handle_post_request = AsyncSocket.handle_post_request + AsyncSocket.handle_post_request = functools.partialmethod( self.__class__._eio_handle_post_request, self) # report websocket packets @@ -97,9 +98,23 @@ class InstrumentedAsyncServer: AsyncSocket._websocket_handler = functools.partialmethod( self.__class__._eio_websocket_handler, self) + def uninstrument(self): # pragma: no cover + if self.mode == 'development': + self.sio.manager.connect = self.sio.manager.__connect + self.sio.manager.disconnect = self.sio.manager.__disconnect + self.sio.manager.enter_room = self.sio.manager.__enter_room + self.sio.manager.leave_room = self.sio.manager.__leave_room + self.sio.manager.emit = self.sio.manager.__emit + self.sio._handle_event_internal = self.sio.__handle_event_internal + self.sio.eio._ok = self.sio.eio.__ok + + from engineio.asyncio_socket import AsyncSocket + AsyncSocket.handle_post_request = AsyncSocket.__handle_post_request + AsyncSocket._websocket_handler = AsyncSocket.__websocket_handler + async def admin_connect(self, sid, environ, client_auth): authenticated = True - if self.auth != None: + if self.auth: authenticated = False if isinstance(self.auth, dict): authenticated = client_auth == self.auth @@ -159,8 +174,9 @@ class InstrumentedAsyncServer: await self.sio.disconnect(sid, namespace=namespace) async def shutdown(self): - self.stop_stats_event.set() - await asyncio.gather(self.stats_task) + if self.stats_task: # pragma: no branch + self.stop_stats_event.set() + await asyncio.gather(self.stats_task) async def _connect(self, eio_sid, namespace): sid = await self.sio.manager.__connect(eio_sid, namespace) @@ -172,22 +188,9 @@ class InstrumentedAsyncServer: datetime.utcfromtimestamp(t).isoformat() + 'Z', ), namespace=self.admin_namespace) - async def check_for_upgrade(): - for _ in range(5): - await self.sio.sleep(5) - try: - if self.sio.eio._get_socket(eio_sid).upgraded: - await self.sio.emit('socket_updated', { - 'id': sid, - 'nsp': namespace, - 'transport': 'websocket', - }, namespace=self.admin_namespace) - break - except KeyError: - pass - if serialized_socket['transport'] == 'polling': - self.sio.start_background_task(check_for_upgrade) + self.sio.start_background_task( + self._check_for_upgrade, eio_sid, sid, namespace) return sid async def _disconnect(self, sid, namespace, **kwargs): @@ -200,6 +203,21 @@ class InstrumentedAsyncServer: ), namespace=self.admin_namespace) return await self.sio.manager.__disconnect(sid, namespace, **kwargs) + async def _check_for_upgrade(self, eio_sid, sid, + namespace): # pragma: no cover + for _ in range(5): + await self.sio.sleep(5) + try: + if self.sio.eio._get_socket(eio_sid).upgraded: + await self.sio.emit('socket_updated', { + 'id': sid, + 'nsp': namespace, + 'transport': 'websocket', + }, namespace=self.admin_namespace) + break + except KeyError: + pass + def _enter_room(self, sid, namespace, room, eio_sid=None): ret = self.sio.manager.__enter_room(sid, namespace, room, eio_sid) if room: @@ -223,13 +241,13 @@ class InstrumentedAsyncServer: async def _emit(self, event, data, namespace, room=None, skip_sid=None, callback=None, **kwargs): - ret = await self.sio.manager.__emit(event, data, namespace, room=room, - skip_sid=skip_sid, callback=callback, - **kwargs) + ret = await self.sio.manager.__emit( + event, data, namespace, room=room, skip_sid=skip_sid, + callback=callback, **kwargs) if namespace != self.admin_namespace: event_data = [event] + list(data) if isinstance(data, tuple) \ else [data] - if not isinstance(skip_sid, list): + if not isinstance(skip_sid, list): # pragma: no branch skip_sid = [skip_sid] for sid, _ in self.sio.manager.get_participants(namespace, room): if sid not in skip_sid: @@ -312,7 +330,7 @@ class InstrumentedAsyncServer: 'namespaces': [{ 'name': nsp, 'socketsCount': len(self.sio.manager.rooms.get( - nsp, {None: []})[None]) + nsp, {None: []}).get(None, [])) } for nsp in namespaces], }, namespace=self.admin_namespace) while self.admin_queue: @@ -321,10 +339,12 @@ class InstrumentedAsyncServer: namespace=self.admin_namespace) def serialize_socket(self, sid, namespace, eio_sid=None): - if eio_sid is None: + if eio_sid is None: # pragma: no cover eio_sid = self.sio.manager.eio_sid_from_sid(sid) socket = self.sio.eio._get_socket(eio_sid) environ = self.sio.environ.get(eio_sid, {}) + tm = self.sio.manager._timestamps[sid] if sid in \ + self.sio.manager._timestamps else 0 return { 'id': sid, 'clientId': eio_sid, @@ -339,9 +359,9 @@ class InstrumentedAsyncServer: environ.get('QUERY_STRING', '')).items()}, 'secure': environ.get('wsgi.url_scheme', '') == 'https', 'url': environ.get('PATH_INFO', ''), - 'issued': self.sio.manager._timestamps[sid] * 1000, - 'time': datetime.utcfromtimestamp( - self.sio.manager._timestamps[sid]).isoformat() + 'Z', + 'issued': tm * 1000, + 'time': datetime.utcfromtimestamp(tm).isoformat() + 'Z' + if tm else '', }, 'rooms': self.sio.manager.get_rooms(sid, namespace), } diff --git a/src/socketio/base_manager.py b/src/socketio/base_manager.py index 6e145a1..d1b0a08 100644 --- a/src/socketio/base_manager.py +++ b/src/socketio/base_manager.py @@ -30,7 +30,7 @@ class BaseManager: def get_participants(self, namespace, room): """Return an iterable with the active participants in a room.""" - ns = self.rooms[namespace] + ns = self.rooms.get(namespace, {}) if hasattr(room, '__len__') and not isinstance(room, str): participants = ns[room[0]]._fwdm.copy() if room[0] in ns else {} for r in room[1:]: diff --git a/src/socketio/simple_client.py b/src/socketio/simple_client.py index 4a88380..ce3a1c5 100644 --- a/src/socketio/simple_client.py +++ b/src/socketio/simple_client.py @@ -23,7 +23,7 @@ class SimpleClient: self.input_buffer = [] def connect(self, url, headers={}, auth=None, transports=None, - namespace='/', socketio_path='socket.io'): + namespace='/', socketio_path='socket.io', wait_timeout=5): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -49,6 +49,9 @@ class SimpleClient: :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait_timeout: How long the client should wait for the + connection to be established. The default is 5 + seconds. """ if self.connected: raise RuntimeError('Already connected') @@ -78,7 +81,8 @@ class SimpleClient: self.client.connect(url, headers=headers, auth=auth, transports=transports, namespaces=[namespace], - socketio_path=socketio_path) + socketio_path=socketio_path, + wait_timeout=wait_timeout) @property def sid(self): @@ -87,7 +91,7 @@ class SimpleClient: The session ID is not guaranteed to remain constant throughout the life of the connection, as reconnections can cause it to change. """ - return self.client.sid if self.client else None + return self.client.get_sid(self.namespace) if self.client else None @property def transport(self): diff --git a/tests/async/test_asyncio_admin.py b/tests/async/test_asyncio_admin.py new file mode 100644 index 0000000..aaad2cf --- /dev/null +++ b/tests/async/test_asyncio_admin.py @@ -0,0 +1,299 @@ +from functools import wraps +import threading +import time +from unittest import mock +import unittest +import pytest +from engineio.asyncio_socket import AsyncSocket as EngineIOSocket +import socketio +from socketio.exceptions import ConnectionError +from tests.asyncio_web_server import SocketIOWebServer +from .helpers import AsyncMock + + +def with_instrumented_server(auth=False, **ikwargs): + """This decorator can be applied to test functions or methods so that they + run with a Socket.IO server that has been instrumented for the official + Admin UI project. The arguments passed to the decorator are passed directly + to the ``instrument()`` method of the server. + """ + def decorator(f): + @wraps(f) + def wrapped(self, *args, **kwargs): + sio = socketio.AsyncServer(async_mode='asgi') + instrumented_server = sio.instrument(auth=auth, **ikwargs) + + @sio.event + def enter_room(sid, data): + sio.enter_room(sid, data) + + @sio.event + async def emit(sid, event): + await sio.emit(event, skip_sid=sid) + + @sio.event(namespace='/foo') + def connect(sid, environ, auth): + pass + + async def shutdown(): + await instrumented_server.shutdown() + await sio.shutdown() + + server = SocketIOWebServer(sio, on_shutdown=shutdown) + server.start() + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.DEBUG) + # logging.getLogger('socketio.client').setLevel(logging.DEBUG) + + original_schedule_ping = EngineIOSocket.schedule_ping + EngineIOSocket.schedule_ping = mock.MagicMock() + + try: + ret = f(self, instrumented_server, *args, **kwargs) + finally: + server.stop() + instrumented_server.uninstrument() + + EngineIOSocket.schedule_ping = original_schedule_ping + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.NOTSET) + # logging.getLogger('socketio.client').setLevel(logging.NOTSET) + + return ret + return wrapped + return decorator + + +def _custom_auth(auth): + return auth == {'foo': 'bar'} + + +async def _async_custom_auth(auth): + return auth == {'foo': 'bar'} + + +class TestAsyncAdmin(unittest.TestCase): + def setUp(self): + print('threads at start:', threading.enumerate()) + self.thread_count = threading.active_count() + + def tearDown(self): + print('threads at end:', threading.enumerate()) + assert self.thread_count == threading.active_count() + + def test_missing_auth(self): + sio = socketio.AsyncServer(async_mode='asgi') + with pytest.raises(ValueError): + sio.instrument() + + @with_instrumented_server(auth=False) + def test_admin_connect_with_no_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + + @with_instrumented_server(auth={'foo': 'bar'}) + def test_admin_connect_with_dict_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin', + auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin') + + @with_instrumented_server(auth=[{'foo': 'bar'}, + {'u': 'admin', 'p': 'secret'}]) + def test_admin_connect_with_list_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'u': 'admin', 'p': 'secret'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server(auth=_custom_auth) + def test_admin_connect_with_function_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server(auth=_async_custom_auth) + def test_admin_connect_with_async_function_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server() + def test_admin_connect_only_admin(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 1 + assert events['all_sockets'][0]['id'] == sid + assert events['all_sockets'][0]['rooms'] == [sid] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_connect_with_others(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as client3, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client1.emit('enter_room', 'room') + sid1 = client1.sid + + saved_check_for_upgrade = isvr._check_for_upgrade + isvr._check_for_upgrade = AsyncMock() + client2.connect('http://localhost:8900', namespace='/foo', + transports=['polling']) + sid2 = client2.sid + isvr._check_for_upgrade = saved_check_for_upgrade + + client3.connect('http://localhost:8900', namespace='/admin') + sid3 = client3.sid + + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 4 + assert events['server_stats']['clientsCount'] == 4 + assert events['server_stats']['pollingClientsCount'] == 1 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 2} in \ + events['server_stats']['namespaces'] + + for socket in events['all_sockets']: + if socket['id'] == sid: + assert socket['rooms'] == [sid] + elif socket['id'] == sid1: + assert socket['rooms'] == [sid1, 'room'] + elif socket['id'] == sid2: + assert socket['rooms'] == [sid2] + elif socket['id'] == sid3: + assert socket['rooms'] == [sid3] + + @with_instrumented_server(mode='production') + def test_admin_connect_production(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + expected = ['config', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' not in events['config']['supportedFeatures'] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_features(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client2.connect('http://localhost:8900') + admin_client.connect('http://localhost:8900', namespace='/admin') + + # emit from admin + admin_client.emit( + 'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra')) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}, 'extra'] + + # emit from regular client + client1.emit('emit', 'foo') + data = client2.receive(timeout=5) + assert data == ['foo'] + + # join and leave + admin_client.emit('join', ('/', 'room', client1.sid)) + admin_client.emit( + 'emit', ('/', 'room', 'foo', {'bar': 'baz'})) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}] + admin_client.emit('leave', ('/', 'room')) + + # disconnect + admin_client.emit('_disconnect', ('/', False, client1.sid)) + for _ in range(10): + if not client1.connected: + break + time.sleep(0.2) + assert not client1.connected diff --git a/tests/async/test_manager.py b/tests/async/test_manager.py index 306734d..0b2fc93 100644 --- a/tests/async/test_manager.py +++ b/tests/async/test_manager.py @@ -353,7 +353,7 @@ class TestAsyncManager(unittest.TestCase): _run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')) def test_emit_with_tuple(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run( self.bm.emit( 'my event', ('foo', 'bar'), namespace='/foo', room=sid @@ -366,7 +366,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event","foo","bar"]' def test_emit_with_list(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run( self.bm.emit( 'my event', ['foo', 'bar'], namespace='/foo', room=sid @@ -379,7 +379,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event",["foo","bar"]]' def test_emit_with_none(self): - sid = self.bm.connect('123', '/foo') + sid = _run(self.bm.connect('123', '/foo')) _run( self.bm.emit( 'my event', None, namespace='/foo', room=sid @@ -392,7 +392,7 @@ class TestAsyncManager(unittest.TestCase): assert pkt.encode() == '42/foo,["my event"]' def test_emit_binary(self): - sid = self.bm.connect('123', '/') + sid = _run(self.bm.connect('123', '/')) _run( self.bm.emit( u'my event', b'my binary data', namespace='/', room=sid diff --git a/tests/async/test_simple_client.py b/tests/async/test_simple_client.py index 397ae98..f8996bd 100644 --- a/tests/async/test_simple_client.py +++ b/tests/async/test_simple_client.py @@ -24,12 +24,13 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase): mock_client.return_value.connect = AsyncMock() _run(client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s')) + namespace='n', socketio_path='s', + wait_timeout='w')) mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.mock.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -44,12 +45,12 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase): await client.connect('url', headers='h', auth='a', transports='t', namespace='n', - socketio_path='s') + socketio_path='s', wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.mock.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -67,7 +68,8 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase): def test_properties(self): client = AsyncSimpleClient() - client.client = mock.MagicMock(sid='sid', transport='websocket') + client.client = mock.MagicMock(transport='websocket') + client.client.get_sid.return_value = 'sid' client.connected_event.set() client.connected = True diff --git a/tests/asyncio_web_server.py b/tests/asyncio_web_server.py new file mode 100644 index 0000000..8b2046c --- /dev/null +++ b/tests/asyncio_web_server.py @@ -0,0 +1,57 @@ +import requests +import threading +import time +import uvicorn +import socketio + + +class SocketIOWebServer: + """A simple web server used for running Socket.IO servers in tests. + + :param sio: a Socket.IO server instance. + + Note 1: This class is not production-ready and is intended for testing. + Note 2: This class only supports the "asgi" async_mode. + """ + def __init__(self, sio, on_shutdown=None): + if sio.async_mode != 'asgi': + raise ValueError('The async_mode must be "asgi"') + + async def http_app(scope, receive, send): + await send({'type': 'http.response.start', + 'status': 200, + 'headers': [('Content-Type', 'text/plain')]}) + await send({'type': 'http.response.body', + 'body': b'OK'}) + + self.sio = sio + self.app = socketio.ASGIApp(sio, http_app, on_shutdown=on_shutdown) + self.httpd = None + self.thread = None + + def start(self, port=8900): + """Start the web server. + + :param port: the port to listen on. Defaults to 8900. + + The server is started in a background thread. + """ + self.httpd = uvicorn.Server(config=uvicorn.Config(self.app, port=port)) + self.thread = threading.Thread(target=self.httpd.run) + self.thread.start() + + # wait for the server to start + while True: + try: + r = requests.get(f'http://localhost:{port}/') + r.raise_for_status() + if r.text == 'OK': + break + except: + time.sleep(0.1) + + def stop(self): + """Stop the web server.""" + self.httpd.should_exit = True + self.thread.join() + self.thread = None diff --git a/tests/common/test_admin.py b/tests/common/test_admin.py new file mode 100644 index 0000000..658d3a3 --- /dev/null +++ b/tests/common/test_admin.py @@ -0,0 +1,277 @@ +from functools import wraps +import threading +import time +from unittest import mock +import unittest +import pytest +from engineio.socket import Socket as EngineIOSocket +import socketio +from socketio.exceptions import ConnectionError +from tests.web_server import SocketIOWebServer + + +def with_instrumented_server(auth=False, **ikwargs): + """This decorator can be applied to test functions or methods so that they + run with a Socket.IO server that has been instrumented for the official + Admin UI project. The arguments passed to the decorator are passed directly + to the ``instrument()`` method of the server. + """ + def decorator(f): + @wraps(f) + def wrapped(self, *args, **kwargs): + sio = socketio.Server(async_mode='threading') + instrumented_server = sio.instrument(auth=auth, **ikwargs) + + @sio.event + def enter_room(sid, data): + sio.enter_room(sid, data) + + @sio.event + def emit(sid, event): + sio.emit(event, skip_sid=sid) + + @sio.event(namespace='/foo') + def connect(sid, environ, auth): + pass + + server = SocketIOWebServer(sio) + server.start() + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.DEBUG) + # logging.getLogger('socketio.client').setLevel(logging.DEBUG) + + original_schedule_ping = EngineIOSocket.schedule_ping + EngineIOSocket.schedule_ping = mock.MagicMock() + + try: + ret = f(self, instrumented_server, *args, **kwargs) + finally: + server.stop() + instrumented_server.shutdown() + instrumented_server.uninstrument() + + EngineIOSocket.schedule_ping = original_schedule_ping + + # import logging + # logging.getLogger('engineio.client').setLevel(logging.NOTSET) + # logging.getLogger('socketio.client').setLevel(logging.NOTSET) + + return ret + return wrapped + return decorator + + +def _custom_auth(auth): + return auth == {'foo': 'bar'} + + +class TestAdmin(unittest.TestCase): + def setUp(self): + print('threads at start:', threading.enumerate()) + self.thread_count = threading.active_count() + + def tearDown(self): + print('threads at end:', threading.enumerate()) + assert self.thread_count == threading.active_count() + + def test_missing_auth(self): + sio = socketio.Server(async_mode='threading') + with pytest.raises(ValueError): + sio.instrument() + + @with_instrumented_server(auth=False) + def test_admin_connect_with_no_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + + @with_instrumented_server(auth={'foo': 'bar'}) + def test_admin_connect_with_dict_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin', + auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect( + 'http://localhost:8900', namespace='/admin') + + @with_instrumented_server(auth=[{'foo': 'bar'}, + {'u': 'admin', 'p': 'secret'}]) + def test_admin_connect_with_list_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'u': 'admin', 'p': 'secret'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server(auth=_custom_auth) + def test_admin_connect_with_function_auth(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin', + auth={'foo': 'bar'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin', auth={'foo': 'baz'}) + with socketio.SimpleClient() as admin_client: + with pytest.raises(ConnectionError): + admin_client.connect('http://localhost:8900', + namespace='/admin') + + @with_instrumented_server() + def test_admin_connect_only_admin(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 1 + assert events['all_sockets'][0]['id'] == sid + assert events['all_sockets'][0]['rooms'] == [sid] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_connect_with_others(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as client3, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client1.emit('enter_room', 'room') + sid1 = client1.sid + + saved_check_for_upgrade = isvr._check_for_upgrade + isvr._check_for_upgrade = mock.MagicMock() + client2.connect('http://localhost:8900', namespace='/foo', + transports=['polling']) + sid2 = client2.sid + isvr._check_for_upgrade = saved_check_for_upgrade + + client3.connect('http://localhost:8900', namespace='/admin') + sid3 = client3.sid + + admin_client.connect('http://localhost:8900', namespace='/admin') + sid = admin_client.sid + expected = ['config', 'all_sockets', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' in events['config']['supportedFeatures'] + assert len(events['all_sockets']) == 4 + assert events['server_stats']['clientsCount'] == 4 + assert events['server_stats']['pollingClientsCount'] == 1 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 2} in \ + events['server_stats']['namespaces'] + + for socket in events['all_sockets']: + if socket['id'] == sid: + assert socket['rooms'] == [sid] + elif socket['id'] == sid1: + assert socket['rooms'] == [sid1, 'room'] + elif socket['id'] == sid2: + assert socket['rooms'] == [sid2] + elif socket['id'] == sid3: + assert socket['rooms'] == [sid3] + + @with_instrumented_server(mode='production') + def test_admin_connect_production(self, isvr): + with socketio.SimpleClient() as admin_client: + admin_client.connect('http://localhost:8900', namespace='/admin') + expected = ['config', 'server_stats'] + events = {} + while expected: + data = admin_client.receive(timeout=5) + if data[0] in expected: + events[data[0]] = data[1] + expected.remove(data[0]) + + assert 'supportedFeatures' in events['config'] + assert 'ALL_EVENTS' not in events['config']['supportedFeatures'] + assert events['server_stats']['clientsCount'] == 1 + assert events['server_stats']['pollingClientsCount'] == 0 + assert len(events['server_stats']['namespaces']) == 3 + assert {'name': '/', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/foo', 'socketsCount': 0} in \ + events['server_stats']['namespaces'] + assert {'name': '/admin', 'socketsCount': 1} in \ + events['server_stats']['namespaces'] + + @with_instrumented_server() + def test_admin_features(self, isvr): + with socketio.SimpleClient() as client1, \ + socketio.SimpleClient() as client2, \ + socketio.SimpleClient() as admin_client: + client1.connect('http://localhost:8900') + client2.connect('http://localhost:8900') + admin_client.connect('http://localhost:8900', namespace='/admin') + + # emit from admin + admin_client.emit( + 'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra')) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}, 'extra'] + + # emit from regular client + client1.emit('emit', 'foo') + data = client2.receive(timeout=5) + assert data == ['foo'] + + # join and leave + admin_client.emit('join', ('/', 'room', client1.sid)) + admin_client.emit( + 'emit', ('/', 'room', 'foo', {'bar': 'baz'})) + data = client1.receive(timeout=5) + assert data == ['foo', {'bar': 'baz'}] + admin_client.emit('leave', ('/', 'room')) + + # disconnect + admin_client.emit('_disconnect', ('/', False, client1.sid)) + for _ in range(10): + if not client1.connected: + break + time.sleep(0.2) + assert not client1.connected diff --git a/tests/common/test_simple_client.py b/tests/common/test_simple_client.py index 2a0b7b7..4069042 100644 --- a/tests/common/test_simple_client.py +++ b/tests/common/test_simple_client.py @@ -18,12 +18,12 @@ class TestSimpleClient(unittest.TestCase): client = SimpleClient(123, a='b') with mock.patch('socketio.simple_client.Client') as mock_client: client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s') + namespace='n', socketio_path='s', wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -33,12 +33,13 @@ class TestSimpleClient(unittest.TestCase): with SimpleClient(123, a='b') as client: with mock.patch('socketio.simple_client.Client') as mock_client: client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s') + namespace='n', socketio_path='s', + wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.assert_called_once_with( 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s') + namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 mock_client().on.called_once_with('*') assert client.namespace == 'n' @@ -54,7 +55,8 @@ class TestSimpleClient(unittest.TestCase): def test_properties(self): client = SimpleClient() - client.client = mock.MagicMock(sid='sid', transport='websocket') + client.client = mock.MagicMock(transport='websocket') + client.client.get_sid.return_value = 'sid' client.connected_event.set() client.connected = True diff --git a/tests/web_server.py b/tests/web_server.py new file mode 100644 index 0000000..cb24668 --- /dev/null +++ b/tests/web_server.py @@ -0,0 +1,81 @@ +import threading +import time +from socketserver import ThreadingMixIn +from wsgiref.simple_server import make_server, WSGIServer, WSGIRequestHandler +import requests +import socketio + + +class SocketIOWebServer: + """A simple web server used for running Socket.IO servers in tests. + + :param sio: a Socket.IO server instance. + + Note 1: This class is not production-ready and is intended for testing. + Note 2: This class only supports the "threading" async_mode, with WebSocket + support provided by the simple-websocket package. + """ + def __init__(self, sio): + if sio.async_mode != 'threading': + raise ValueError('The async_mode must be "threading"') + + def http_app(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'OK'] + + self.sio = sio + self.app = socketio.WSGIApp(sio, http_app) + self.httpd = None + self.thread = None + + def start(self, port=8900): + """Start the web server. + + :param port: the port to listen on. Defaults to 8900. + + The server is started in a background thread. + """ + class ThreadingWSGIServer(ThreadingMixIn, WSGIServer): + pass + + class WebSocketRequestHandler(WSGIRequestHandler): + def get_environ(self): + env = super().get_environ() + + # pass the raw socket to the WSGI app so that it can be used + # by WebSocket connections (hack copied from gunicorn) + env['gunicorn.socket'] = self.connection + return env + + self.httpd = make_server('', port, self._app_wrapper, + ThreadingWSGIServer, WebSocketRequestHandler) + self.thread = threading.Thread(target=self.httpd.serve_forever) + self.thread.start() + + # wait for the server to start + while True: + try: + r = requests.get(f'http://localhost:{port}/') + r.raise_for_status() + if r.text == 'OK': + break + except: + time.sleep(0.1) + + def stop(self): + """Stop the web server.""" + self.sio.shutdown() + self.httpd.shutdown() + self.httpd.server_close() + self.thread.join() + self.httpd = None + self.thread = None + + def _app_wrapper(self, environ, start_response): + try: + return self.app(environ, start_response) + except StopIteration: + # end the WebSocket request without sending a response + # (this is a hack that was copied from gunicorn's threaded worker) + start_response('200 OK', []) + return [] diff --git a/tox.ini b/tox.ini index 12d2891..2929457 100644 --- a/tox.ini +++ b/tox.ini @@ -14,10 +14,16 @@ python = [testenv] commands= pip install -e . - pytest -p no:logging --cov=socketio --cov-branch --cov-report=term-missing --cov-report=xml + pytest -p no:logging --timeout=60 --cov=socketio --cov-branch --cov-report=term-missing --cov-report=xml deps= + simple-websocket + uvicorn + requests + websocket-client + aiohttp msgpack pytest + pytest-timeout pytest-cov [testenv:flake8]