You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

308 lines
13 KiB

from functools import wraps
import threading
import time
from unittest import mock
import pytest
try:
from engineio.async_socket import AsyncSocket as EngineIOSocket
except ImportError:
from engineio.asyncio_socket import AsyncSocket as EngineIOSocket
import socketio
from socketio.exceptions import ConnectionError
from tests.asyncio_web_server import SocketIOWebServer
def with_instrumented_server(auth=False, **ikwargs):
    """This decorator can be applied to test functions or methods so that they
    run with a Socket.IO server that has been instrumented for the official
    Admin UI project. The arguments passed to the decorator are passed directly
    to the ``instrument()`` method of the server.

    If ``server_stats_interval`` is not given it defaults to 0.25. While the
    test runs, ``self.isvr`` holds the instrumented server and
    ``EngineIOSocket.schedule_ping`` is replaced with a ``MagicMock``. All of
    this is undone afterwards, even if the test or an earlier cleanup step
    raises.
    """
    # Merge the default into a copy. The captured ``ikwargs`` dict is shared
    # by every invocation of the wrapped test, so it is never mutated.
    instrument_kwargs = {'server_stats_interval': 0.25, **ikwargs}

    def decorator(f):
        @wraps(f)
        def wrapped(self, *args, **kwargs):
            sio = socketio.AsyncServer(async_mode='asgi')

            @sio.event
            async def enter_room(sid, data):
                await sio.enter_room(sid, data)

            @sio.event
            async def emit(sid, event):
                # rebroadcast the event name to every client except the sender
                await sio.emit(event, skip_sid=sid)

            @sio.event(namespace='/foo')
            def connect(sid, environ, auth):
                # accepts everything; presumably present so that '/foo' is a
                # registered namespace in the admin stats
                pass

            async def shutdown():
                # called by the web server on shutdown: shut down the
                # instrumented server first, then the Socket.IO server
                await self.isvr.shutdown()
                await sio.shutdown()

            self.isvr = sio.instrument(auth=auth, **instrument_kwargs)
            server = SocketIOWebServer(sio, on_shutdown=shutdown)
            server.start()
            original_schedule_ping = EngineIOSocket.schedule_ping
            EngineIOSocket.schedule_ping = mock.MagicMock()
            try:
                return f(self, *args, **kwargs)
            finally:
                # Nested so that one failing cleanup step cannot skip the
                # others. In particular, the class-level schedule_ping patch
                # would otherwise leak into every test that runs afterwards.
                try:
                    server.stop()
                finally:
                    try:
                        self.isvr.uninstrument()
                    finally:
                        self.isvr = None
                        EngineIOSocket.schedule_ping = original_schedule_ping

        return wrapped

    return decorator
def _custom_auth(auth):
    """Synchronous admin auth callback: admit only ``{'foo': 'bar'}``."""
    accepted = {'foo': 'bar'}
    return auth == accepted
async def _async_custom_auth(auth):
    """Coroutine admin auth callback: admit only ``{'foo': 'bar'}``."""
    accepted = {'foo': 'bar'}
    return auth == accepted
class TestAsyncAdmin:
def setup_method(self):
    """Record how many threads are alive before the test starts."""
    live_threads = threading.enumerate()
    print('threads at start:', live_threads)
    self.thread_count = threading.active_count()
def teardown_method(self):
    """Fail the test if it left behind a different number of threads than
    were alive when it started."""
    live_threads = threading.enumerate()
    print('threads at end:', live_threads)
    assert threading.active_count() == self.thread_count
def _expect(self, expected, admin_client):
    """Receive events until each expected event has arrived often enough.

    :param expected: mapping of event name -> number of occurrences to wait
                     for. The mapping is not modified.
    :param admin_client: a connected client exposing ``receive(timeout=...)``
                         that returns ``[event_name, *args]``.
    :returns: dict mapping each expected event name to the first argument
              of its last counted occurrence.

    Events not listed in ``expected`` are read and discarded. Each receive
    waits at most 5 seconds; whatever ``receive`` raises propagates.
    """
    # Work on a copy so the caller's mapping is left intact.
    remaining = dict(expected)
    events = {}
    while remaining:
        data = admin_client.receive(timeout=5)
        name = data[0]
        if name not in remaining:
            continue
        if remaining[name] == 1:
            # final awaited occurrence: keep its payload and stop waiting
            events[name] = data[1]
            del remaining[name]
        else:
            remaining[name] -= 1
    return events
def test_missing_auth(self):
    """``instrument()`` called without ``auth`` must raise ValueError."""
    server = socketio.AsyncServer(async_mode='asgi')
    with pytest.raises(ValueError):
        server.instrument()
@with_instrumented_server(auth=False)
def test_admin_connect_with_no_auth(self):
    """With ``auth=False`` the admin namespace accepts a client that sends
    no credentials and also one that sends some."""
    for extra in ({}, {'auth': {'foo': 'bar'}}):
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900',
                                 namespace='/admin', **extra)
@with_instrumented_server(auth={'foo': 'bar'})
def test_admin_connect_with_dict_auth(self):
    """A dict ``auth`` admits only a client sending exactly that dict."""
    url = 'http://localhost:8900'
    with socketio.SimpleClient() as admin_client:
        admin_client.connect(url, namespace='/admin', auth={'foo': 'bar'})
    # wrong credentials first, then no credentials at all
    for rejected in ({'auth': {'foo': 'baz'}}, {}):
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(url, namespace='/admin', **rejected)
@with_instrumented_server(auth=[{'foo': 'bar'},
                                {'u': 'admin', 'p': 'secret'}])
def test_admin_connect_with_list_auth(self):
    """A list ``auth`` admits a client whose credentials match any entry."""
    url = 'http://localhost:8900'
    for accepted in ({'foo': 'bar'}, {'u': 'admin', 'p': 'secret'}):
        with socketio.SimpleClient() as admin_client:
            admin_client.connect(url, namespace='/admin', auth=accepted)
    # wrong credentials first, then no credentials at all
    for rejected in ({'auth': {'foo': 'baz'}}, {}):
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(url, namespace='/admin', **rejected)
@with_instrumented_server(auth=_custom_auth)
def test_admin_connect_with_function_auth(self):
    """A plain function ``auth`` decides admission; ``_custom_auth`` lets
    only ``{'foo': 'bar'}`` through."""
    url = 'http://localhost:8900'
    with socketio.SimpleClient() as admin_client:
        admin_client.connect(url, namespace='/admin', auth={'foo': 'bar'})
    # wrong credentials first, then no credentials at all
    for rejected in ({'auth': {'foo': 'baz'}}, {}):
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(url, namespace='/admin', **rejected)
@with_instrumented_server(auth=_async_custom_auth)
def test_admin_connect_with_async_function_auth(self):
    """A coroutine-function ``auth`` decides admission; only
    ``{'foo': 'bar'}`` passes ``_async_custom_auth``."""
    url = 'http://localhost:8900'
    with socketio.SimpleClient() as admin_client:
        admin_client.connect(url, namespace='/admin', auth={'foo': 'bar'})
    # wrong credentials first, then no credentials at all
    for rejected in ({'auth': {'foo': 'baz'}}, {}):
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(url, namespace='/admin', **rejected)
@with_instrumented_server()
def test_admin_connect_only_admin(self):
    """A lone admin client is reported as the only socket, with full
    (development-mode) features advertised."""
    with socketio.SimpleClient() as admin_client:
        admin_client.connect('http://localhost:8900', namespace='/admin')
        sid = admin_client.sid
        events = self._expect({'config': 1, 'all_sockets': 1,
                               'server_stats': 2}, admin_client)

    config = events['config']
    assert 'supportedFeatures' in config
    for feature in ('ALL_EVENTS', 'AGGREGATED_EVENTS', 'EMIT'):
        assert feature in config['supportedFeatures']

    sockets = events['all_sockets']
    assert len(sockets) == 1
    assert sockets[0]['id'] == sid
    assert sockets[0]['rooms'] == [sid]

    stats = events['server_stats']
    assert stats['clientsCount'] == 1
    assert stats['pollingClientsCount'] == 0
    namespaces = stats['namespaces']
    assert len(namespaces) == 3
    for name, count in (('/', 0), ('/foo', 0), ('/admin', 1)):
        assert {'name': name, 'socketsCount': count} in namespaces
@with_instrumented_server()
def test_admin_connect_with_others(self):
    """With clients on '/', '/foo' (polling only) and '/admin', the admin
    client sees all four sockets, their rooms, and per-namespace counts."""
    url = 'http://localhost:8900'
    with socketio.SimpleClient() as client1, \
            socketio.SimpleClient() as client2, \
            socketio.SimpleClient() as client3, \
            socketio.SimpleClient() as admin_client:
        # default-namespace client that also joins 'room'
        client1.connect(url)
        client1.emit('enter_room', 'room')
        sid1 = client1.sid

        # polling-only client on '/foo'; the instrumented server's
        # _check_for_upgrade is stubbed while it connects, then restored
        real_check_for_upgrade = self.isvr._check_for_upgrade
        self.isvr._check_for_upgrade = mock.AsyncMock()
        client2.connect(url, namespace='/foo', transports=['polling'])
        sid2 = client2.sid
        self.isvr._check_for_upgrade = real_check_for_upgrade

        # another admin client, then the admin client under test
        client3.connect(url, namespace='/admin')
        sid3 = client3.sid
        admin_client.connect(url, namespace='/admin')
        sid = admin_client.sid
        events = self._expect({'config': 1, 'all_sockets': 1,
                               'server_stats': 2}, admin_client)

    config = events['config']
    assert 'supportedFeatures' in config
    for feature in ('ALL_EVENTS', 'AGGREGATED_EVENTS', 'EMIT'):
        assert feature in config['supportedFeatures']
    assert len(events['all_sockets']) == 4

    stats = events['server_stats']
    assert stats['clientsCount'] == 4
    assert stats['pollingClientsCount'] == 1
    namespaces = stats['namespaces']
    assert len(namespaces) == 3
    for name, count in (('/', 1), ('/foo', 1), ('/admin', 2)):
        assert {'name': name, 'socketsCount': count} in namespaces

    # first matching sid wins, mirroring an if/elif chain
    expected_rooms = (
        (sid, [sid]),
        (sid1, [sid1, 'room']),
        (sid2, [sid2]),
        (sid3, [sid3]),
    )
    for socket in events['all_sockets']:
        for known_sid, rooms in expected_rooms:
            if socket['id'] == known_sid:
                assert socket['rooms'] == rooms
                break
@with_instrumented_server(mode='production', read_only=True)
def test_admin_connect_production(self):
    """In production, read-only mode only ``AGGREGATED_EVENTS`` is
    advertised; per-event reporting and admin emits are not."""
    with socketio.SimpleClient() as admin_client:
        admin_client.connect('http://localhost:8900', namespace='/admin')
        events = self._expect({'config': 1, 'server_stats': 2},
                              admin_client)

    config = events['config']
    assert 'supportedFeatures' in config
    features = config['supportedFeatures']
    assert 'ALL_EVENTS' not in features
    assert 'AGGREGATED_EVENTS' in features
    assert 'EMIT' not in features

    stats = events['server_stats']
    assert stats['clientsCount'] == 1
    assert stats['pollingClientsCount'] == 0
    namespaces = stats['namespaces']
    assert len(namespaces) == 3
    for name, count in (('/', 0), ('/foo', 0), ('/admin', 1)):
        assert {'name': name, 'socketsCount': count} in namespaces
@with_instrumented_server()
def test_admin_features(self):
    """Drive admin-initiated actions against regular clients.

    Covers an admin emit to a single socket, a client emit relayed to the
    other client, an admin join to a room followed by an emit to that room,
    a room leave, and a forced disconnect.
    """
    with socketio.SimpleClient() as client1, \
            socketio.SimpleClient() as client2, \
            socketio.SimpleClient() as admin_client:
        client1.connect('http://localhost:8900')
        client2.connect('http://localhost:8900')
        admin_client.connect('http://localhost:8900', namespace='/admin')
        # emit from admin
        # payload as used here: (namespace, target sid, event, *args); the
        # target receives [event, *args]
        admin_client.emit(
            'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra'))
        data = client1.receive(timeout=5)
        assert data == ['foo', {'bar': 'baz'}, 'extra']
        # emit from regular client
        # the server's 'emit' handler rebroadcasts the event name to every
        # client except the sender, so client2 receives it
        client1.emit('emit', 'foo')
        data = client2.receive(timeout=5)
        assert data == ['foo']
        # join and leave
        admin_client.emit('join', ('/', 'room', client1.sid))
        # NOTE(review): fixed delay, presumably to let the join be applied
        # before emitting to the room; a possible source of flakiness
        time.sleep(0.2)
        admin_client.emit(
            'emit', ('/', 'room', 'foo', {'bar': 'baz'}))
        data = client1.receive(timeout=5)
        assert data == ['foo', {'bar': 'baz'}]
        # the leave is sent, but its effect is not checked by this test
        admin_client.emit('leave', ('/', 'room'))
        # disconnect
        admin_client.emit('_disconnect', ('/', False, client1.sid))
        # poll up to 10 times, 0.2s apart, for the forced disconnect to land
        for _ in range(10):
            if not client1.connected:
                break
            time.sleep(0.2)
        assert not client1.connected