You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

328 lines
13 KiB

import threading
import time
from functools import wraps
from unittest import mock
import pytest
try:
from engineio.async_socket import AsyncSocket as EngineIOSocket
except ImportError:
from engineio.asyncio_socket import AsyncSocket as EngineIOSocket
import socketio
from fastsio.exceptions import ConnectionError
from tests.asyncio_web_server import SocketIOWebServer
def with_instrumented_server(auth=False, **ikwargs):
    """Run the decorated test against a Socket.IO server instrumented for the Admin UI.

    This decorator can be applied to test methods so that they run with a
    Socket.IO server that has been instrumented for the official Admin UI
    project. ``auth`` and any extra keyword arguments are passed directly to
    the server's ``instrument()`` method; ``server_stats_interval`` defaults
    to 0.25 when the caller does not supply it.

    While the test runs:

    * the instrumented server is available as ``self.isvr``;
    * ``EngineIOSocket.schedule_ping`` is replaced with a ``MagicMock``
      (presumably to keep engine.io ping tasks out of the tests — confirm).

    Teardown stops the web server, uninstruments the server and restores
    ``schedule_ping``. Each of these steps still runs if the test or an
    earlier cleanup step raises.
    """
    # Apply the default once, on a copy, so the decorator's captured kwargs
    # are never mutated between runs of the wrapped test.
    instrument_kwargs = {"server_stats_interval": 0.25, **ikwargs}

    def decorator(f):
        @wraps(f)
        def wrapped(self, *args, **kwargs):
            sio = socketio.AsyncServer(async_mode="asgi")

            @sio.event
            async def enter_room(sid, data):
                # Lets a regular client put itself in a room.
                await sio.enter_room(sid, data)

            @sio.event
            async def emit(sid, event):
                # Re-broadcasts the named event to everyone except the sender.
                await sio.emit(event, skip_sid=sid)

            @sio.event(namespace="/foo")
            def connect(sid, environ, auth):
                # Registering a handler makes "/foo" a known namespace, which
                # the tests expect to see in the server stats.
                pass

            isvr = sio.instrument(auth=auth, **instrument_kwargs)
            self.isvr = isvr

            async def shutdown():
                # Use the captured local rather than self.isvr, which is
                # cleared during teardown.
                await isvr.shutdown()
                await sio.shutdown()

            server = SocketIOWebServer(sio, on_shutdown=shutdown)
            server.start()
            # patch.object restores schedule_ping on exit even if the test,
            # server.stop() or uninstrument() raises.
            with mock.patch.object(EngineIOSocket, "schedule_ping", mock.MagicMock()):
                try:
                    return f(self, *args, **kwargs)
                finally:
                    try:
                        server.stop()
                    finally:
                        self.isvr = None
                        isvr.uninstrument()

        return wrapped

    return decorator
def _custom_auth(auth):
return auth == {"foo": "bar"}
async def _async_custom_auth(auth):
return auth == {"foo": "bar"}
class TestAsyncAdmin:
    """Admin UI instrumentation tests for an ASGI ``socketio.AsyncServer``.

    Tests decorated with ``with_instrumented_server`` get a freshly
    instrumented server (exposed as ``self.isvr``). They talk to it through
    blocking ``socketio.SimpleClient`` instances at http://localhost:8900.
    """

    # NOTE(review): ``ConnectionError`` is imported from ``fastsio.exceptions``
    # while the clients are ``socketio.SimpleClient``; confirm the clients
    # raise that same exception class, or the ``pytest.raises`` checks below
    # will not match.

    def setup_method(self):
        # Snapshot the thread count so teardown can detect leaked threads.
        print("threads at start:", threading.enumerate())
        self.thread_count = threading.active_count()

    def teardown_method(self):
        # The thread count must be back to its starting value.
        print("threads at end:", threading.enumerate())
        assert self.thread_count == threading.active_count()

    def _expect(self, expected, admin_client):
        """Read events from ``admin_client`` until all expected counts are met.

        ``expected`` maps an event name to how many occurrences to wait for.
        Events that are not listed are skipped. The caller's dict is left
        unmodified.

        Returns a dict that maps each expected event name to the first
        argument of the occurrence that satisfied its count, i.e. the last
        one read. Each ``receive`` call uses a 5 second timeout.
        """
        remaining = dict(expected)  # work on a copy; don't consume the caller's dict
        events = {}
        while remaining:
            data = admin_client.receive(timeout=5)
            name = data[0]
            if name not in remaining:
                continue  # not an event we are waiting for
            if remaining[name] == 1:
                # Final required occurrence: keep its payload, stop waiting.
                events[name] = data[1]
                del remaining[name]
            else:
                remaining[name] -= 1
        return events

    def test_missing_auth(self):
        """instrument() called without ``auth`` raises ValueError."""
        sio = socketio.AsyncServer(async_mode="asgi")
        with pytest.raises(ValueError):
            sio.instrument()

    @with_instrumented_server(auth=False)
    def test_admin_connect_with_no_auth(self):
        """With auth disabled, admins connect with or without credentials."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect("http://localhost:8900", namespace="/admin")
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                "http://localhost:8900", namespace="/admin", auth={"foo": "bar"}
            )

    @with_instrumented_server(auth={"foo": "bar"})
    def test_admin_connect_with_dict_auth(self):
        """A dict ``auth`` accepts only that exact payload."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                "http://localhost:8900", namespace="/admin", auth={"foo": "bar"}
            )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    "http://localhost:8900", namespace="/admin", auth={"foo": "baz"}
                )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect("http://localhost:8900", namespace="/admin")

    @with_instrumented_server(auth=[{"foo": "bar"}, {"u": "admin", "p": "secret"}])
    def test_admin_connect_with_list_auth(self):
        """A list ``auth`` accepts any listed payload and nothing else."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                "http://localhost:8900", namespace="/admin", auth={"foo": "bar"}
            )
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                "http://localhost:8900",
                namespace="/admin",
                auth={"u": "admin", "p": "secret"},
            )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    "http://localhost:8900", namespace="/admin", auth={"foo": "baz"}
                )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect("http://localhost:8900", namespace="/admin")

    @with_instrumented_server(auth=_custom_auth)
    def test_admin_connect_with_function_auth(self):
        """A sync callable ``auth`` decides which admins may connect."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                "http://localhost:8900", namespace="/admin", auth={"foo": "bar"}
            )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    "http://localhost:8900", namespace="/admin", auth={"foo": "baz"}
                )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect("http://localhost:8900", namespace="/admin")

    @with_instrumented_server(auth=_async_custom_auth)
    def test_admin_connect_with_async_function_auth(self):
        """An async callable ``auth`` decides which admins may connect."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(
                "http://localhost:8900", namespace="/admin", auth={"foo": "bar"}
            )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    "http://localhost:8900", namespace="/admin", auth={"foo": "baz"}
                )
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect("http://localhost:8900", namespace="/admin")

    @with_instrumented_server()
    def test_admin_connect_only_admin(self):
        """A lone admin sees itself, all features and empty app namespaces."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect("http://localhost:8900", namespace="/admin")
            sid = admin_client.sid
            # server_stats is awaited twice and _expect keeps the last one,
            # presumably so the stats reflect this connection — confirm.
            events = self._expect(
                {"config": 1, "all_sockets": 1, "server_stats": 2}, admin_client
            )

        assert "supportedFeatures" in events["config"]
        features = events["config"]["supportedFeatures"]
        assert "ALL_EVENTS" in features
        assert "AGGREGATED_EVENTS" in features
        assert "EMIT" in features

        assert len(events["all_sockets"]) == 1
        assert events["all_sockets"][0]["id"] == sid
        assert events["all_sockets"][0]["rooms"] == [sid]

        stats = events["server_stats"]
        assert stats["clientsCount"] == 1
        assert stats["pollingClientsCount"] == 0
        assert len(stats["namespaces"]) == 3
        assert {"name": "/", "socketsCount": 0} in stats["namespaces"]
        assert {"name": "/foo", "socketsCount": 0} in stats["namespaces"]
        assert {"name": "/admin", "socketsCount": 1} in stats["namespaces"]

    @with_instrumented_server()
    def test_admin_connect_with_others(self):
        """The admin sees regular, polling-only and other admin clients."""
        with (
            socketio.SimpleClient() as client1,
            socketio.SimpleClient() as client2,
            socketio.SimpleClient() as client3,
            socketio.SimpleClient() as admin_client,
        ):
            client1.connect("http://localhost:8900")
            client1.emit("enter_room", "room")
            sid1 = client1.sid

            # client2 is polling-only (checked via pollingClientsCount below).
            # The instrumented server's _check_for_upgrade is stubbed while it
            # connects; patch.object undoes that even if the connect fails.
            with mock.patch.object(
                self.isvr, "_check_for_upgrade", mock.AsyncMock()
            ):
                client2.connect(
                    "http://localhost:8900", namespace="/foo", transports=["polling"]
                )
                sid2 = client2.sid

            client3.connect("http://localhost:8900", namespace="/admin")
            sid3 = client3.sid

            admin_client.connect("http://localhost:8900", namespace="/admin")
            sid = admin_client.sid
            events = self._expect(
                {"config": 1, "all_sockets": 1, "server_stats": 2}, admin_client
            )

        assert "supportedFeatures" in events["config"]
        features = events["config"]["supportedFeatures"]
        assert "ALL_EVENTS" in features
        assert "AGGREGATED_EVENTS" in features
        assert "EMIT" in features

        assert len(events["all_sockets"]) == 4

        stats = events["server_stats"]
        assert stats["clientsCount"] == 4
        assert stats["pollingClientsCount"] == 1
        assert len(stats["namespaces"]) == 3
        assert {"name": "/", "socketsCount": 1} in stats["namespaces"]
        assert {"name": "/foo", "socketsCount": 1} in stats["namespaces"]
        assert {"name": "/admin", "socketsCount": 2} in stats["namespaces"]

        for sock in events["all_sockets"]:
            if sock["id"] == sid:
                assert sock["rooms"] == [sid]
            elif sock["id"] == sid1:
                assert sock["rooms"] == [sid1, "room"]
            elif sock["id"] == sid2:
                assert sock["rooms"] == [sid2]
            elif sock["id"] == sid3:
                assert sock["rooms"] == [sid3]

    @with_instrumented_server(mode="production", read_only=True)
    def test_admin_connect_production(self):
        """Production/read-only mode hides ALL_EVENTS and EMIT."""
        with socketio.SimpleClient() as admin_client:
            admin_client.connect("http://localhost:8900", namespace="/admin")
            events = self._expect({"config": 1, "server_stats": 2}, admin_client)

        assert "supportedFeatures" in events["config"]
        features = events["config"]["supportedFeatures"]
        assert "ALL_EVENTS" not in features
        assert "AGGREGATED_EVENTS" in features
        assert "EMIT" not in features

        stats = events["server_stats"]
        assert stats["clientsCount"] == 1
        assert stats["pollingClientsCount"] == 0
        assert len(stats["namespaces"]) == 3
        assert {"name": "/", "socketsCount": 0} in stats["namespaces"]
        assert {"name": "/foo", "socketsCount": 0} in stats["namespaces"]
        assert {"name": "/admin", "socketsCount": 1} in stats["namespaces"]

    @with_instrumented_server()
    def test_admin_features(self):
        """Admin emit, join/leave and disconnect commands reach clients."""
        with (
            socketio.SimpleClient() as client1,
            socketio.SimpleClient() as client2,
            socketio.SimpleClient() as admin_client,
        ):
            client1.connect("http://localhost:8900")
            client2.connect("http://localhost:8900")
            admin_client.connect("http://localhost:8900", namespace="/admin")

            # emit from admin
            admin_client.emit(
                "emit", ("/", client1.sid, "foo", {"bar": "baz"}, "extra")
            )
            data = client1.receive(timeout=5)
            assert data == ["foo", {"bar": "baz"}, "extra"]

            # emit from regular client (re-broadcast to everyone but sender)
            client1.emit("emit", "foo")
            data = client2.receive(timeout=5)
            assert data == ["foo"]

            # join and leave; the short sleep gives the server time to
            # process the join before emitting to the room
            admin_client.emit("join", ("/", "room", client1.sid))
            time.sleep(0.2)
            admin_client.emit("emit", ("/", "room", "foo", {"bar": "baz"}))
            data = client1.receive(timeout=5)
            assert data == ["foo", {"bar": "baz"}]
            admin_client.emit("leave", ("/", "room"))

            # disconnect: poll for up to ~2 seconds for client1 to drop
            admin_client.emit("_disconnect", ("/", False, client1.sid))
            for _ in range(10):
                if not client1.connected:
                    break
                time.sleep(0.2)
            assert not client1.connected