import threading
import time
from functools import wraps
from unittest import mock

import pytest

try:
    from engineio.async_socket import AsyncSocket as EngineIOSocket
except ImportError:
    from engineio.asyncio_socket import AsyncSocket as EngineIOSocket

import socketio

# NOTE(review): this shadows the builtin ConnectionError for the whole module;
# it is assumed to be the exception socketio.SimpleClient.connect() raises on a
# rejected connection -- confirm against the fastsio/socketio packaging.
from fastsio.exceptions import ConnectionError
from tests.asyncio_web_server import SocketIOWebServer

# Address every client in this module connects to.
SERVER_URL = "http://localhost:8900"


def with_instrumented_server(auth=False, **ikwargs):
    """This decorator can be applied to test functions or methods so that they
    run with a Socket.IO server that has been instrumented for the official
    Admin UI project. The arguments passed to the decorator are passed
    directly to the ``instrument()`` method of the server.

    ``server_stats_interval`` defaults to 0.25 when it is not given. While the
    decorated test runs, the instrumented server object is available as
    ``self.isvr`` and ``EngineIOSocket.schedule_ping`` is replaced by a
    ``MagicMock``; both are undone when the test returns or raises.
    """
    # Effective instrument() options, built once instead of writing into the
    # captured ``ikwargs`` dict on every test invocation.
    instrument_kwargs = {"server_stats_interval": 0.25, **ikwargs}

    def decorator(f):
        @wraps(f)
        def wrapped(self, *args, **kwargs):
            sio = socketio.AsyncServer(async_mode="asgi")

            @sio.event
            async def enter_room(sid, data):
                await sio.enter_room(sid, data)

            @sio.event
            async def emit(sid, event):
                await sio.emit(event, skip_sid=sid)

            @sio.event(namespace="/foo")
            def connect(sid, environ, auth):
                pass

            async def shutdown():
                await self.isvr.shutdown()
                await sio.shutdown()

            self.isvr = sio.instrument(auth=auth, **instrument_kwargs)
            server = SocketIOWebServer(sio, on_shutdown=shutdown)
            server.start()

            # import logging
            # logging.getLogger('engineio.client').setLevel(logging.DEBUG)
            # logging.getLogger('socketio.client').setLevel(logging.DEBUG)

            # patch.object restores the class attribute even when the cleanup
            # in ``finally`` raises, so the patch cannot leak into other tests.
            with mock.patch.object(
                EngineIOSocket, "schedule_ping", mock.MagicMock()
            ):
                try:
                    return f(self, *args, **kwargs)
                finally:
                    server.stop()
                    self.isvr.uninstrument()
                    self.isvr = None

                    # import logging
                    # logging.getLogger('engineio.client').setLevel(logging.NOTSET)
                    # logging.getLogger('socketio.client').setLevel(logging.NOTSET)

        return wrapped

    return decorator


def _custom_auth(auth):
    """Synchronous auth callback: accept only ``{"foo": "bar"}``."""
    return auth == {"foo": "bar"}


async def _async_custom_auth(auth):
    """Coroutine auth callback: accept only ``{"foo": "bar"}``."""
    return auth == {"foo": "bar"}


class TestAsyncAdmin:
    """Tests for the Admin UI instrumentation of ``socketio.AsyncServer``."""

    def setup_method(self):
        """Record the number of live threads before the test starts."""
        print("threads at start:", threading.enumerate())
        self.thread_count = threading.active_count()

    def teardown_method(self):
        """Fail the test if the number of live threads changed."""
        print("threads at end:", threading.enumerate())
        assert self.thread_count == threading.active_count()

    def _expect(self, expected, admin_client):
        """Receive events until every expected one has been seen enough times.

        :param expected: mapping of event name to the number of occurrences
                         to wait for. The mapping is not modified.
        :param admin_client: client whose ``receive(timeout=5)`` is polled.
        :return: dict mapping each event name to the first argument of its
                 last counted occurrence. Events not listed in ``expected``
                 are discarded.
        """
        remaining = dict(expected)  # work on a copy, leave the caller's dict
        events = {}
        while remaining:
            data = admin_client.receive(timeout=5)
            name = data[0]
            if name not in remaining:
                continue
            remaining[name] -= 1
            if remaining[name] == 0:
                events[name] = data[1]
                del remaining[name]
        return events

    def _assert_admin_accepted(self, **connect_kwargs):
        """Connect a fresh client to ``/admin`` and expect it to succeed."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                SERVER_URL, namespace="/admin", **connect_kwargs
            )

    def _assert_admin_rejected(self, **connect_kwargs):
        """Connect a fresh client to ``/admin`` and expect ConnectionError."""
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    SERVER_URL, namespace="/admin", **connect_kwargs
                )

    def _assert_only_foo_bar_accepted(self):
        """Check that ``{"foo": "bar"}`` is accepted and that a different
        credential or no credential at all is rejected."""
        self._assert_admin_accepted(auth={"foo": "bar"})
        self._assert_admin_rejected(auth={"foo": "baz"})
        self._assert_admin_rejected()

    @staticmethod
    def _assert_server_stats(stats, clients, polling, sockets_per_namespace):
        """Check a ``server_stats`` payload.

        :param stats: the payload of a ``server_stats`` event.
        :param clients: expected ``clientsCount``.
        :param polling: expected ``pollingClientsCount``.
        :param sockets_per_namespace: mapping of namespace name to expected
                                      ``socketsCount``; the payload must list
                                      exactly this many namespaces.
        """
        assert stats["clientsCount"] == clients
        assert stats["pollingClientsCount"] == polling
        assert len(stats["namespaces"]) == len(sockets_per_namespace)
        for name, count in sockets_per_namespace.items():
            assert {"name": name, "socketsCount": count} in stats["namespaces"]

    def test_missing_auth(self):
        """``instrument()`` without arguments raises ``ValueError``."""
        sio = socketio.AsyncServer(async_mode="asgi")
        with pytest.raises(ValueError):
            sio.instrument()

    @with_instrumented_server(auth=False)
    def test_admin_connect_with_no_auth(self):
        """With ``auth=False`` clients connect with or without credentials."""
        self._assert_admin_accepted()
        self._assert_admin_accepted(auth={"foo": "bar"})

    @with_instrumented_server(auth={"foo": "bar"})
    def test_admin_connect_with_dict_auth(self):
        """A dict ``auth`` accepts only the matching credential."""
        self._assert_only_foo_bar_accepted()

    @with_instrumented_server(
        auth=[{"foo": "bar"}, {"u": "admin", "p": "secret"}]
    )
    def test_admin_connect_with_list_auth(self):
        """A list ``auth`` accepts any of the listed credentials."""
        self._assert_admin_accepted(auth={"foo": "bar"})
        self._assert_admin_accepted(auth={"u": "admin", "p": "secret"})
        self._assert_admin_rejected(auth={"foo": "baz"})
        self._assert_admin_rejected()

    @with_instrumented_server(auth=_custom_auth)
    def test_admin_connect_with_function_auth(self):
        """A function ``auth`` decides which credential is accepted."""
        self._assert_only_foo_bar_accepted()

    @with_instrumented_server(auth=_async_custom_auth)
    def test_admin_connect_with_async_function_auth(self):
        """A coroutine ``auth`` decides which credential is accepted."""
        self._assert_only_foo_bar_accepted()

    @with_instrumented_server()
    def test_admin_connect_only_admin(self):
        """Initial events received when the admin is the only client."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(SERVER_URL, namespace="/admin")
            sid = admin_client.sid
            events = self._expect(
                {"config": 1, "all_sockets": 1, "server_stats": 2},
                admin_client,
            )

        assert "supportedFeatures" in events["config"]
        assert "ALL_EVENTS" in events["config"]["supportedFeatures"]
        assert "AGGREGATED_EVENTS" in events["config"]["supportedFeatures"]
        assert "EMIT" in events["config"]["supportedFeatures"]

        assert len(events["all_sockets"]) == 1
        assert events["all_sockets"][0]["id"] == sid
        assert events["all_sockets"][0]["rooms"] == [sid]

        self._assert_server_stats(
            events["server_stats"],
            clients=1,
            polling=0,
            sockets_per_namespace={"/": 0, "/foo": 0, "/admin": 1},
        )

    @with_instrumented_server()
    def test_admin_connect_with_others(self):
        """Initial events received when three other clients are connected:
        one on ``/`` in a room, one polling-only on ``/foo`` and one more on
        ``/admin``."""
        with (
            socketio.SimpleClient() as client1,
            socketio.SimpleClient() as client2,
            socketio.SimpleClient() as client3,
            socketio.SimpleClient() as admin_client,
        ):
            client1.connect(SERVER_URL)
            client1.emit("enter_room", "room")
            sid1 = client1.sid

            # The upgrade check is replaced only while client2 connects
            # (presumably so that it stays counted as a polling client --
            # TODO confirm). patch.object restores it even if connect() fails.
            with mock.patch.object(
                self.isvr, "_check_for_upgrade", mock.AsyncMock()
            ):
                client2.connect(
                    SERVER_URL, namespace="/foo", transports=["polling"]
                )
                sid2 = client2.sid

            client3.connect(SERVER_URL, namespace="/admin")
            sid3 = client3.sid

            admin_client.connect(SERVER_URL, namespace="/admin")
            sid = admin_client.sid
            events = self._expect(
                {"config": 1, "all_sockets": 1, "server_stats": 2},
                admin_client,
            )

        assert "supportedFeatures" in events["config"]
        assert "ALL_EVENTS" in events["config"]["supportedFeatures"]
        assert "AGGREGATED_EVENTS" in events["config"]["supportedFeatures"]
        assert "EMIT" in events["config"]["supportedFeatures"]

        assert len(events["all_sockets"]) == 4

        self._assert_server_stats(
            events["server_stats"],
            clients=4,
            polling=1,
            sockets_per_namespace={"/": 1, "/foo": 1, "/admin": 2},
        )

        expected_rooms = {
            sid: [sid],
            sid1: [sid1, "room"],
            sid2: [sid2],
            sid3: [sid3],
        }
        for socket in events["all_sockets"]:
            if socket["id"] in expected_rooms:
                assert socket["rooms"] == expected_rooms[socket["id"]]

    @with_instrumented_server(mode="production", read_only=True)
    def test_admin_connect_production(self):
        """In read-only production mode the config omits ``ALL_EVENTS`` and
        ``EMIT`` and no ``all_sockets`` event is awaited."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(SERVER_URL, namespace="/admin")
            events = self._expect(
                {"config": 1, "server_stats": 2}, admin_client
            )

        assert "supportedFeatures" in events["config"]
        assert "ALL_EVENTS" not in events["config"]["supportedFeatures"]
        assert "AGGREGATED_EVENTS" in events["config"]["supportedFeatures"]
        assert "EMIT" not in events["config"]["supportedFeatures"]

        self._assert_server_stats(
            events["server_stats"],
            clients=1,
            polling=0,
            sockets_per_namespace={"/": 0, "/foo": 0, "/admin": 1},
        )

    @with_instrumented_server()
    def test_admin_features(self):
        """Exercise the admin ``emit``, ``join``, ``leave`` and
        ``_disconnect`` commands against regular clients."""
        with (
            socketio.SimpleClient() as client1,
            socketio.SimpleClient() as client2,
            socketio.SimpleClient() as admin_client,
        ):
            client1.connect(SERVER_URL)
            client2.connect(SERVER_URL)
            admin_client.connect(SERVER_URL, namespace="/admin")

            # emit from admin
            admin_client.emit(
                "emit", ("/", client1.sid, "foo", {"bar": "baz"}, "extra")
            )
            data = client1.receive(timeout=5)
            assert data == ["foo", {"bar": "baz"}, "extra"]

            # emit from regular client
            client1.emit("emit", "foo")
            data = client2.receive(timeout=5)
            assert data == ["foo"]

            # join and leave
            admin_client.emit("join", ("/", "room", client1.sid))
            time.sleep(0.2)
            admin_client.emit("emit", ("/", "room", "foo", {"bar": "baz"}))
            data = client1.receive(timeout=5)
            assert data == ["foo", {"bar": "baz"}]
            admin_client.emit("leave", ("/", "room"))

            # disconnect
            admin_client.emit("_disconnect", ("/", False, client1.sid))
            # poll for up to ~2 seconds for the disconnect to take effect
            for _ in range(10):
                if not client1.connected:
                    break
                time.sleep(0.2)
            assert not client1.connected