Browse Source

tests

pull/1164/head
Miguel Grinberg 2 years ago
parent
commit
b37477706d
Failed to extract signature
  1. 2
      .github/workflows/tests.yml
  2. 70
      src/socketio/admin.py
  3. 10
      src/socketio/async_simple_client.py
  4. 78
      src/socketio/asyncio_admin.py
  5. 2
      src/socketio/base_manager.py
  6. 10
      src/socketio/simple_client.py
  7. 299
      tests/async/test_asyncio_admin.py
  8. 8
      tests/async/test_manager.py
  9. 12
      tests/async/test_simple_client.py
  10. 57
      tests/asyncio_web_server.py
  11. 277
      tests/common/test_admin.py
  12. 12
      tests/common/test_simple_client.py
  13. 81
      tests/web_server.py
  14. 8
      tox.ini

2
.github/workflows/tests.yml

@ -26,7 +26,7 @@ jobs:
exclude: exclude:
# pypy3 currently fails to run on Windows # pypy3 currently fails to run on Windows
- os: windows-latest - os: windows-latest
python: pypy-3.8 python: pypy-3.9
fail-fast: false fail-fast: false
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:

70
src/socketio/admin.py

@ -63,7 +63,7 @@ class InstrumentedServer:
namespace=self.admin_namespace) namespace=self.admin_namespace)
if self.mode == 'development': if self.mode == 'development':
if not self.read_only: if not self.read_only: # pragma: no branch
self.sio.on('emit', self.admin_emit, self.sio.on('emit', self.admin_emit,
namespace=self.admin_namespace) namespace=self.admin_namespace)
self.sio.on('join', self.admin_enter_room, self.sio.on('join', self.admin_enter_room,
@ -117,8 +117,22 @@ class InstrumentedServer:
Socket._websocket_handler = functools.partialmethod( Socket._websocket_handler = functools.partialmethod(
self.__class__._eio_websocket_handler, self) self.__class__._eio_websocket_handler, self)
def uninstrument(self): # pragma: no cover
if self.mode == 'development':
self.sio.manager.connect = self.sio.manager.__connect
self.sio.manager.disconnect = self.sio.manager.__disconnect
self.sio.manager.enter_room = self.sio.manager.__enter_room
self.sio.manager.leave_room = self.sio.manager.__leave_room
self.sio.manager.emit = self.sio.manager.__emit
self.sio._handle_event_internal = self.sio.__handle_event_internal
self.sio.eio._ok = self.sio.eio.__ok
from engineio.socket import Socket
Socket.handle_post_request = Socket.__handle_post_request
Socket._websocket_handler = Socket.__websocket_handler
def admin_connect(self, sid, environ, client_auth): def admin_connect(self, sid, environ, client_auth):
if self.auth != None: if self.auth:
authenticated = False authenticated = False
if isinstance(self.auth, dict): if isinstance(self.auth, dict):
authenticated = client_auth == self.auth authenticated = client_auth == self.auth
@ -175,8 +189,9 @@ class InstrumentedServer:
self.sio.disconnect(sid, namespace=namespace) self.sio.disconnect(sid, namespace=namespace)
def shutdown(self): def shutdown(self):
self.stop_stats_event.set() if self.stats_task: # pragma: no branch
self.stats_thread.join() self.stop_stats_event.set()
self.stats_task.join()
def _connect(self, eio_sid, namespace): def _connect(self, eio_sid, namespace):
sid = self.sio.manager.__connect(eio_sid, namespace) sid = self.sio.manager.__connect(eio_sid, namespace)
@ -188,22 +203,9 @@ class InstrumentedServer:
datetime.utcfromtimestamp(t).isoformat() + 'Z', datetime.utcfromtimestamp(t).isoformat() + 'Z',
), namespace=self.admin_namespace) ), namespace=self.admin_namespace)
def check_for_upgrade(): if serialized_socket['transport'] == 'polling': # pragma: no cover
for _ in range(5): self.sio.start_background_task(
self.sio.sleep(5) self._check_for_upgrade, eio_sid, sid, namespace)
try:
if self.sio.eio._get_socket(eio_sid).upgraded:
self.sio.emit('socket_updated', {
'id': sid,
'nsp': namespace,
'transport': 'websocket',
}, namespace=self.admin_namespace)
break
except KeyError:
pass
if serialized_socket['transport'] == 'polling':
self.sio.start_background_task(check_for_upgrade)
return sid return sid
def _disconnect(self, sid, namespace, **kwargs): def _disconnect(self, sid, namespace, **kwargs):
@ -216,6 +218,20 @@ class InstrumentedServer:
), namespace=self.admin_namespace) ), namespace=self.admin_namespace)
return self.sio.manager.__disconnect(sid, namespace, **kwargs) return self.sio.manager.__disconnect(sid, namespace, **kwargs)
def _check_for_upgrade(self, eio_sid, sid, namespace): # pragma: no cover
for _ in range(5):
self.sio.sleep(5)
try:
if self.sio.eio._get_socket(eio_sid).upgraded:
self.sio.emit('socket_updated', {
'id': sid,
'nsp': namespace,
'transport': 'websocket',
}, namespace=self.admin_namespace)
break
except KeyError:
pass
def _enter_room(self, sid, namespace, room, eio_sid=None): def _enter_room(self, sid, namespace, room, eio_sid=None):
ret = self.sio.manager.__enter_room(sid, namespace, room, eio_sid) ret = self.sio.manager.__enter_room(sid, namespace, room, eio_sid)
if room: if room:
@ -245,7 +261,7 @@ class InstrumentedServer:
if namespace != self.admin_namespace: if namespace != self.admin_namespace:
event_data = [event] + list(data) if isinstance(data, tuple) \ event_data = [event] + list(data) if isinstance(data, tuple) \
else [data] else [data]
if not isinstance(skip_sid, list): if not isinstance(skip_sid, list): # pragma: no branch
skip_sid = [skip_sid] skip_sid = [skip_sid]
for sid, _ in self.sio.manager.get_participants(namespace, room): for sid, _ in self.sio.manager.get_participants(namespace, room):
if sid not in skip_sid: if sid not in skip_sid:
@ -328,15 +344,17 @@ class InstrumentedServer:
'namespaces': [{ 'namespaces': [{
'name': nsp, 'name': nsp,
'socketsCount': len(self.sio.manager.rooms.get( 'socketsCount': len(self.sio.manager.rooms.get(
nsp, {None: []})[None]) nsp, {None: []}).get(None, []))
} for nsp in namespaces], } for nsp in namespaces],
}, namespace=self.admin_namespace) }, namespace=self.admin_namespace)
def serialize_socket(self, sid, namespace, eio_sid=None): def serialize_socket(self, sid, namespace, eio_sid=None):
if eio_sid is None: if eio_sid is None: # pragma: no cover
eio_sid = self.sio.manager.eio_sid_from_sid(sid) eio_sid = self.sio.manager.eio_sid_from_sid(sid)
socket = self.sio.eio._get_socket(eio_sid) socket = self.sio.eio._get_socket(eio_sid)
environ = self.sio.environ.get(eio_sid, {}) environ = self.sio.environ.get(eio_sid, {})
tm = self.sio.manager._timestamps[sid] if sid in \
self.sio.manager._timestamps else 0
return { return {
'id': sid, 'id': sid,
'clientId': eio_sid, 'clientId': eio_sid,
@ -351,9 +369,9 @@ class InstrumentedServer:
environ.get('QUERY_STRING', '')).items()}, environ.get('QUERY_STRING', '')).items()},
'secure': environ.get('wsgi.url_scheme', '') == 'https', 'secure': environ.get('wsgi.url_scheme', '') == 'https',
'url': environ.get('PATH_INFO', ''), 'url': environ.get('PATH_INFO', ''),
'issued': self.sio.manager._timestamps[sid] * 1000, 'issued': tm * 1000,
'time': datetime.utcfromtimestamp( 'time': datetime.utcfromtimestamp(tm).isoformat() + 'Z'
self.sio.manager._timestamps[sid]).isoformat() + 'Z', if tm else '',
}, },
'rooms': self.sio.manager.get_rooms(sid, namespace), 'rooms': self.sio.manager.get_rooms(sid, namespace),
} }

10
src/socketio/async_simple_client.py

@ -23,7 +23,8 @@ class AsyncSimpleClient:
self.input_buffer = [] self.input_buffer = []
async def connect(self, url, headers={}, auth=None, transports=None, async def connect(self, url, headers={}, auth=None, transports=None,
namespace='/', socketio_path='socket.io'): namespace='/', socketio_path='socket.io',
wait_timeout=5):
"""Connect to a Socket.IO server. """Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom :param url: The URL of the Socket.IO server. It can include custom
@ -49,6 +50,8 @@ class AsyncSimpleClient:
:param socketio_path: The endpoint where the Socket.IO server is :param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for installed. The default value is appropriate for
most cases. most cases.
:param wait_timeout: How long the client should wait for the
connection. The default is 5 seconds.
Note: this method is a coroutine. Note: this method is a coroutine.
""" """
@ -80,7 +83,8 @@ class AsyncSimpleClient:
await self.client.connect( await self.client.connect(
url, headers=headers, auth=auth, transports=transports, url, headers=headers, auth=auth, transports=transports,
namespaces=[namespace], socketio_path=socketio_path) namespaces=[namespace], socketio_path=socketio_path,
wait_timeout=wait_timeout)
@property @property
def sid(self): def sid(self):
@ -89,7 +93,7 @@ class AsyncSimpleClient:
The session ID is not guaranteed to remain constant throughout the life The session ID is not guaranteed to remain constant throughout the life
of the connection, as reconnections can cause it to change. of the connection, as reconnections can cause it to change.
""" """
return self.client.sid if self.client else None return self.client.get_sid(self.namespace) if self.client else None
@property @property
def transport(self): def transport(self):

78
src/socketio/asyncio_admin.py

@ -44,7 +44,7 @@ class InstrumentedAsyncServer:
namespace=self.admin_namespace) namespace=self.admin_namespace)
if self.mode == 'development': if self.mode == 'development':
if not self.read_only: if not self.read_only: # pragma: no branch
self.sio.on('emit', self.admin_emit, self.sio.on('emit', self.admin_emit,
namespace=self.admin_namespace) namespace=self.admin_namespace)
self.sio.on('join', self.admin_enter_room, self.sio.on('join', self.admin_enter_room,
@ -89,7 +89,8 @@ class InstrumentedAsyncServer:
from engineio.asyncio_socket import AsyncSocket from engineio.asyncio_socket import AsyncSocket
self.sio.eio.__ok = self.sio.eio._ok self.sio.eio.__ok = self.sio.eio._ok
self.sio.eio._ok = self._eio_http_response self.sio.eio._ok = self._eio_http_response
AsyncSocket.__handle_post_request = functools.partialmethod( AsyncSocket.__handle_post_request = AsyncSocket.handle_post_request
AsyncSocket.handle_post_request = functools.partialmethod(
self.__class__._eio_handle_post_request, self) self.__class__._eio_handle_post_request, self)
# report websocket packets # report websocket packets
@ -97,9 +98,23 @@ class InstrumentedAsyncServer:
AsyncSocket._websocket_handler = functools.partialmethod( AsyncSocket._websocket_handler = functools.partialmethod(
self.__class__._eio_websocket_handler, self) self.__class__._eio_websocket_handler, self)
def uninstrument(self): # pragma: no cover
if self.mode == 'development':
self.sio.manager.connect = self.sio.manager.__connect
self.sio.manager.disconnect = self.sio.manager.__disconnect
self.sio.manager.enter_room = self.sio.manager.__enter_room
self.sio.manager.leave_room = self.sio.manager.__leave_room
self.sio.manager.emit = self.sio.manager.__emit
self.sio._handle_event_internal = self.sio.__handle_event_internal
self.sio.eio._ok = self.sio.eio.__ok
from engineio.asyncio_socket import AsyncSocket
AsyncSocket.handle_post_request = AsyncSocket.__handle_post_request
AsyncSocket._websocket_handler = AsyncSocket.__websocket_handler
async def admin_connect(self, sid, environ, client_auth): async def admin_connect(self, sid, environ, client_auth):
authenticated = True authenticated = True
if self.auth != None: if self.auth:
authenticated = False authenticated = False
if isinstance(self.auth, dict): if isinstance(self.auth, dict):
authenticated = client_auth == self.auth authenticated = client_auth == self.auth
@ -159,8 +174,9 @@ class InstrumentedAsyncServer:
await self.sio.disconnect(sid, namespace=namespace) await self.sio.disconnect(sid, namespace=namespace)
async def shutdown(self): async def shutdown(self):
self.stop_stats_event.set() if self.stats_task: # pragma: no branch
await asyncio.gather(self.stats_task) self.stop_stats_event.set()
await asyncio.gather(self.stats_task)
async def _connect(self, eio_sid, namespace): async def _connect(self, eio_sid, namespace):
sid = await self.sio.manager.__connect(eio_sid, namespace) sid = await self.sio.manager.__connect(eio_sid, namespace)
@ -172,22 +188,9 @@ class InstrumentedAsyncServer:
datetime.utcfromtimestamp(t).isoformat() + 'Z', datetime.utcfromtimestamp(t).isoformat() + 'Z',
), namespace=self.admin_namespace) ), namespace=self.admin_namespace)
async def check_for_upgrade():
for _ in range(5):
await self.sio.sleep(5)
try:
if self.sio.eio._get_socket(eio_sid).upgraded:
await self.sio.emit('socket_updated', {
'id': sid,
'nsp': namespace,
'transport': 'websocket',
}, namespace=self.admin_namespace)
break
except KeyError:
pass
if serialized_socket['transport'] == 'polling': if serialized_socket['transport'] == 'polling':
self.sio.start_background_task(check_for_upgrade) self.sio.start_background_task(
self._check_for_upgrade, eio_sid, sid, namespace)
return sid return sid
async def _disconnect(self, sid, namespace, **kwargs): async def _disconnect(self, sid, namespace, **kwargs):
@ -200,6 +203,21 @@ class InstrumentedAsyncServer:
), namespace=self.admin_namespace) ), namespace=self.admin_namespace)
return await self.sio.manager.__disconnect(sid, namespace, **kwargs) return await self.sio.manager.__disconnect(sid, namespace, **kwargs)
async def _check_for_upgrade(self, eio_sid, sid,
namespace): # pragma: no cover
for _ in range(5):
await self.sio.sleep(5)
try:
if self.sio.eio._get_socket(eio_sid).upgraded:
await self.sio.emit('socket_updated', {
'id': sid,
'nsp': namespace,
'transport': 'websocket',
}, namespace=self.admin_namespace)
break
except KeyError:
pass
def _enter_room(self, sid, namespace, room, eio_sid=None): def _enter_room(self, sid, namespace, room, eio_sid=None):
ret = self.sio.manager.__enter_room(sid, namespace, room, eio_sid) ret = self.sio.manager.__enter_room(sid, namespace, room, eio_sid)
if room: if room:
@ -223,13 +241,13 @@ class InstrumentedAsyncServer:
async def _emit(self, event, data, namespace, room=None, skip_sid=None, async def _emit(self, event, data, namespace, room=None, skip_sid=None,
callback=None, **kwargs): callback=None, **kwargs):
ret = await self.sio.manager.__emit(event, data, namespace, room=room, ret = await self.sio.manager.__emit(
skip_sid=skip_sid, callback=callback, event, data, namespace, room=room, skip_sid=skip_sid,
**kwargs) callback=callback, **kwargs)
if namespace != self.admin_namespace: if namespace != self.admin_namespace:
event_data = [event] + list(data) if isinstance(data, tuple) \ event_data = [event] + list(data) if isinstance(data, tuple) \
else [data] else [data]
if not isinstance(skip_sid, list): if not isinstance(skip_sid, list): # pragma: no branch
skip_sid = [skip_sid] skip_sid = [skip_sid]
for sid, _ in self.sio.manager.get_participants(namespace, room): for sid, _ in self.sio.manager.get_participants(namespace, room):
if sid not in skip_sid: if sid not in skip_sid:
@ -312,7 +330,7 @@ class InstrumentedAsyncServer:
'namespaces': [{ 'namespaces': [{
'name': nsp, 'name': nsp,
'socketsCount': len(self.sio.manager.rooms.get( 'socketsCount': len(self.sio.manager.rooms.get(
nsp, {None: []})[None]) nsp, {None: []}).get(None, []))
} for nsp in namespaces], } for nsp in namespaces],
}, namespace=self.admin_namespace) }, namespace=self.admin_namespace)
while self.admin_queue: while self.admin_queue:
@ -321,10 +339,12 @@ class InstrumentedAsyncServer:
namespace=self.admin_namespace) namespace=self.admin_namespace)
def serialize_socket(self, sid, namespace, eio_sid=None): def serialize_socket(self, sid, namespace, eio_sid=None):
if eio_sid is None: if eio_sid is None: # pragma: no cover
eio_sid = self.sio.manager.eio_sid_from_sid(sid) eio_sid = self.sio.manager.eio_sid_from_sid(sid)
socket = self.sio.eio._get_socket(eio_sid) socket = self.sio.eio._get_socket(eio_sid)
environ = self.sio.environ.get(eio_sid, {}) environ = self.sio.environ.get(eio_sid, {})
tm = self.sio.manager._timestamps[sid] if sid in \
self.sio.manager._timestamps else 0
return { return {
'id': sid, 'id': sid,
'clientId': eio_sid, 'clientId': eio_sid,
@ -339,9 +359,9 @@ class InstrumentedAsyncServer:
environ.get('QUERY_STRING', '')).items()}, environ.get('QUERY_STRING', '')).items()},
'secure': environ.get('wsgi.url_scheme', '') == 'https', 'secure': environ.get('wsgi.url_scheme', '') == 'https',
'url': environ.get('PATH_INFO', ''), 'url': environ.get('PATH_INFO', ''),
'issued': self.sio.manager._timestamps[sid] * 1000, 'issued': tm * 1000,
'time': datetime.utcfromtimestamp( 'time': datetime.utcfromtimestamp(tm).isoformat() + 'Z'
self.sio.manager._timestamps[sid]).isoformat() + 'Z', if tm else '',
}, },
'rooms': self.sio.manager.get_rooms(sid, namespace), 'rooms': self.sio.manager.get_rooms(sid, namespace),
} }

2
src/socketio/base_manager.py

@ -30,7 +30,7 @@ class BaseManager:
def get_participants(self, namespace, room): def get_participants(self, namespace, room):
"""Return an iterable with the active participants in a room.""" """Return an iterable with the active participants in a room."""
ns = self.rooms[namespace] ns = self.rooms.get(namespace, {})
if hasattr(room, '__len__') and not isinstance(room, str): if hasattr(room, '__len__') and not isinstance(room, str):
participants = ns[room[0]]._fwdm.copy() if room[0] in ns else {} participants = ns[room[0]]._fwdm.copy() if room[0] in ns else {}
for r in room[1:]: for r in room[1:]:

10
src/socketio/simple_client.py

@ -23,7 +23,7 @@ class SimpleClient:
self.input_buffer = [] self.input_buffer = []
def connect(self, url, headers={}, auth=None, transports=None, def connect(self, url, headers={}, auth=None, transports=None,
namespace='/', socketio_path='socket.io'): namespace='/', socketio_path='socket.io', wait_timeout=5):
"""Connect to a Socket.IO server. """Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom :param url: The URL of the Socket.IO server. It can include custom
@ -49,6 +49,9 @@ class SimpleClient:
:param socketio_path: The endpoint where the Socket.IO server is :param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for installed. The default value is appropriate for
most cases. most cases.
:param wait_timeout: How long the client should wait for the
connection to be established. The default is 5
seconds.
""" """
if self.connected: if self.connected:
raise RuntimeError('Already connected') raise RuntimeError('Already connected')
@ -78,7 +81,8 @@ class SimpleClient:
self.client.connect(url, headers=headers, auth=auth, self.client.connect(url, headers=headers, auth=auth,
transports=transports, namespaces=[namespace], transports=transports, namespaces=[namespace],
socketio_path=socketio_path) socketio_path=socketio_path,
wait_timeout=wait_timeout)
@property @property
def sid(self): def sid(self):
@ -87,7 +91,7 @@ class SimpleClient:
The session ID is not guaranteed to remain constant throughout the life The session ID is not guaranteed to remain constant throughout the life
of the connection, as reconnections can cause it to change. of the connection, as reconnections can cause it to change.
""" """
return self.client.sid if self.client else None return self.client.get_sid(self.namespace) if self.client else None
@property @property
def transport(self): def transport(self):

299
tests/async/test_asyncio_admin.py

@ -0,0 +1,299 @@
from functools import wraps
import threading
import time
from unittest import mock
import unittest
import pytest
from engineio.asyncio_socket import AsyncSocket as EngineIOSocket
import socketio
from socketio.exceptions import ConnectionError
from tests.asyncio_web_server import SocketIOWebServer
from .helpers import AsyncMock
def with_instrumented_server(auth=False, **ikwargs):
"""This decorator can be applied to test functions or methods so that they
run with a Socket.IO server that has been instrumented for the official
Admin UI project. The arguments passed to the decorator are passed directly
to the ``instrument()`` method of the server.
"""
def decorator(f):
@wraps(f)
def wrapped(self, *args, **kwargs):
sio = socketio.AsyncServer(async_mode='asgi')
instrumented_server = sio.instrument(auth=auth, **ikwargs)
@sio.event
def enter_room(sid, data):
sio.enter_room(sid, data)
@sio.event
async def emit(sid, event):
await sio.emit(event, skip_sid=sid)
@sio.event(namespace='/foo')
def connect(sid, environ, auth):
pass
async def shutdown():
await instrumented_server.shutdown()
await sio.shutdown()
server = SocketIOWebServer(sio, on_shutdown=shutdown)
server.start()
# import logging
# logging.getLogger('engineio.client').setLevel(logging.DEBUG)
# logging.getLogger('socketio.client').setLevel(logging.DEBUG)
original_schedule_ping = EngineIOSocket.schedule_ping
EngineIOSocket.schedule_ping = mock.MagicMock()
try:
ret = f(self, instrumented_server, *args, **kwargs)
finally:
server.stop()
instrumented_server.uninstrument()
EngineIOSocket.schedule_ping = original_schedule_ping
# import logging
# logging.getLogger('engineio.client').setLevel(logging.NOTSET)
# logging.getLogger('socketio.client').setLevel(logging.NOTSET)
return ret
return wrapped
return decorator
def _custom_auth(auth):
return auth == {'foo': 'bar'}
async def _async_custom_auth(auth):
return auth == {'foo': 'bar'}
class TestAsyncAdmin(unittest.TestCase):
def setUp(self):
print('threads at start:', threading.enumerate())
self.thread_count = threading.active_count()
def tearDown(self):
print('threads at end:', threading.enumerate())
assert self.thread_count == threading.active_count()
def test_missing_auth(self):
sio = socketio.AsyncServer(async_mode='asgi')
with pytest.raises(ValueError):
sio.instrument()
@with_instrumented_server(auth=False)
def test_admin_connect_with_no_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin')
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
@with_instrumented_server(auth={'foo': 'bar'})
def test_admin_connect_with_dict_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect(
'http://localhost:8900', namespace='/admin',
auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect(
'http://localhost:8900', namespace='/admin')
@with_instrumented_server(auth=[{'foo': 'bar'},
{'u': 'admin', 'p': 'secret'}])
def test_admin_connect_with_list_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'u': 'admin', 'p': 'secret'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin', auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin')
@with_instrumented_server(auth=_custom_auth)
def test_admin_connect_with_function_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin', auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin')
@with_instrumented_server(auth=_async_custom_auth)
def test_admin_connect_with_async_function_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin', auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin')
@with_instrumented_server()
def test_admin_connect_only_admin(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin')
sid = admin_client.sid
expected = ['config', 'all_sockets', 'server_stats']
events = {}
while expected:
data = admin_client.receive(timeout=5)
if data[0] in expected:
events[data[0]] = data[1]
expected.remove(data[0])
assert 'supportedFeatures' in events['config']
assert 'ALL_EVENTS' in events['config']['supportedFeatures']
assert len(events['all_sockets']) == 1
assert events['all_sockets'][0]['id'] == sid
assert events['all_sockets'][0]['rooms'] == [sid]
assert events['server_stats']['clientsCount'] == 1
assert events['server_stats']['pollingClientsCount'] == 0
assert len(events['server_stats']['namespaces']) == 3
assert {'name': '/', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/foo', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/admin', 'socketsCount': 1} in \
events['server_stats']['namespaces']
@with_instrumented_server()
def test_admin_connect_with_others(self, isvr):
with socketio.SimpleClient() as client1, \
socketio.SimpleClient() as client2, \
socketio.SimpleClient() as client3, \
socketio.SimpleClient() as admin_client:
client1.connect('http://localhost:8900')
client1.emit('enter_room', 'room')
sid1 = client1.sid
saved_check_for_upgrade = isvr._check_for_upgrade
isvr._check_for_upgrade = AsyncMock()
client2.connect('http://localhost:8900', namespace='/foo',
transports=['polling'])
sid2 = client2.sid
isvr._check_for_upgrade = saved_check_for_upgrade
client3.connect('http://localhost:8900', namespace='/admin')
sid3 = client3.sid
admin_client.connect('http://localhost:8900', namespace='/admin')
sid = admin_client.sid
expected = ['config', 'all_sockets', 'server_stats']
events = {}
while expected:
data = admin_client.receive(timeout=5)
if data[0] in expected:
events[data[0]] = data[1]
expected.remove(data[0])
assert 'supportedFeatures' in events['config']
assert 'ALL_EVENTS' in events['config']['supportedFeatures']
assert len(events['all_sockets']) == 4
assert events['server_stats']['clientsCount'] == 4
assert events['server_stats']['pollingClientsCount'] == 1
assert len(events['server_stats']['namespaces']) == 3
assert {'name': '/', 'socketsCount': 1} in \
events['server_stats']['namespaces']
assert {'name': '/foo', 'socketsCount': 1} in \
events['server_stats']['namespaces']
assert {'name': '/admin', 'socketsCount': 2} in \
events['server_stats']['namespaces']
for socket in events['all_sockets']:
if socket['id'] == sid:
assert socket['rooms'] == [sid]
elif socket['id'] == sid1:
assert socket['rooms'] == [sid1, 'room']
elif socket['id'] == sid2:
assert socket['rooms'] == [sid2]
elif socket['id'] == sid3:
assert socket['rooms'] == [sid3]
@with_instrumented_server(mode='production')
def test_admin_connect_production(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin')
expected = ['config', 'server_stats']
events = {}
while expected:
data = admin_client.receive(timeout=5)
if data[0] in expected:
events[data[0]] = data[1]
expected.remove(data[0])
assert 'supportedFeatures' in events['config']
assert 'ALL_EVENTS' not in events['config']['supportedFeatures']
assert events['server_stats']['clientsCount'] == 1
assert events['server_stats']['pollingClientsCount'] == 0
assert len(events['server_stats']['namespaces']) == 3
assert {'name': '/', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/foo', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/admin', 'socketsCount': 1} in \
events['server_stats']['namespaces']
@with_instrumented_server()
def test_admin_features(self, isvr):
with socketio.SimpleClient() as client1, \
socketio.SimpleClient() as client2, \
socketio.SimpleClient() as admin_client:
client1.connect('http://localhost:8900')
client2.connect('http://localhost:8900')
admin_client.connect('http://localhost:8900', namespace='/admin')
# emit from admin
admin_client.emit(
'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra'))
data = client1.receive(timeout=5)
assert data == ['foo', {'bar': 'baz'}, 'extra']
# emit from regular client
client1.emit('emit', 'foo')
data = client2.receive(timeout=5)
assert data == ['foo']
# join and leave
admin_client.emit('join', ('/', 'room', client1.sid))
admin_client.emit(
'emit', ('/', 'room', 'foo', {'bar': 'baz'}))
data = client1.receive(timeout=5)
assert data == ['foo', {'bar': 'baz'}]
admin_client.emit('leave', ('/', 'room'))
# disconnect
admin_client.emit('_disconnect', ('/', False, client1.sid))
for _ in range(10):
if not client1.connected:
break
time.sleep(0.2)
assert not client1.connected

8
tests/async/test_manager.py

@ -353,7 +353,7 @@ class TestAsyncManager(unittest.TestCase):
_run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo')) _run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo'))
def test_emit_with_tuple(self): def test_emit_with_tuple(self):
sid = self.bm.connect('123', '/foo') sid = _run(self.bm.connect('123', '/foo'))
_run( _run(
self.bm.emit( self.bm.emit(
'my event', ('foo', 'bar'), namespace='/foo', room=sid 'my event', ('foo', 'bar'), namespace='/foo', room=sid
@ -366,7 +366,7 @@ class TestAsyncManager(unittest.TestCase):
assert pkt.encode() == '42/foo,["my event","foo","bar"]' assert pkt.encode() == '42/foo,["my event","foo","bar"]'
def test_emit_with_list(self): def test_emit_with_list(self):
sid = self.bm.connect('123', '/foo') sid = _run(self.bm.connect('123', '/foo'))
_run( _run(
self.bm.emit( self.bm.emit(
'my event', ['foo', 'bar'], namespace='/foo', room=sid 'my event', ['foo', 'bar'], namespace='/foo', room=sid
@ -379,7 +379,7 @@ class TestAsyncManager(unittest.TestCase):
assert pkt.encode() == '42/foo,["my event",["foo","bar"]]' assert pkt.encode() == '42/foo,["my event",["foo","bar"]]'
def test_emit_with_none(self): def test_emit_with_none(self):
sid = self.bm.connect('123', '/foo') sid = _run(self.bm.connect('123', '/foo'))
_run( _run(
self.bm.emit( self.bm.emit(
'my event', None, namespace='/foo', room=sid 'my event', None, namespace='/foo', room=sid
@ -392,7 +392,7 @@ class TestAsyncManager(unittest.TestCase):
assert pkt.encode() == '42/foo,["my event"]' assert pkt.encode() == '42/foo,["my event"]'
def test_emit_binary(self): def test_emit_binary(self):
sid = self.bm.connect('123', '/') sid = _run(self.bm.connect('123', '/'))
_run( _run(
self.bm.emit( self.bm.emit(
u'my event', b'my binary data', namespace='/', room=sid u'my event', b'my binary data', namespace='/', room=sid

12
tests/async/test_simple_client.py

@ -24,12 +24,13 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase):
mock_client.return_value.connect = AsyncMock() mock_client.return_value.connect = AsyncMock()
_run(client.connect('url', headers='h', auth='a', transports='t', _run(client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s')) namespace='n', socketio_path='s',
wait_timeout='w'))
mock_client.assert_called_once_with(123, a='b') mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client() assert client.client == mock_client()
mock_client().connect.mock.assert_called_once_with( mock_client().connect.mock.assert_called_once_with(
'url', headers='h', auth='a', transports='t', 'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s') namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3 mock_client().event.call_count == 3
mock_client().on.called_once_with('*') mock_client().on.called_once_with('*')
assert client.namespace == 'n' assert client.namespace == 'n'
@ -44,12 +45,12 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase):
await client.connect('url', headers='h', auth='a', await client.connect('url', headers='h', auth='a',
transports='t', namespace='n', transports='t', namespace='n',
socketio_path='s') socketio_path='s', wait_timeout='w')
mock_client.assert_called_once_with(123, a='b') mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client() assert client.client == mock_client()
mock_client().connect.mock.assert_called_once_with( mock_client().connect.mock.assert_called_once_with(
'url', headers='h', auth='a', transports='t', 'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s') namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3 mock_client().event.call_count == 3
mock_client().on.called_once_with('*') mock_client().on.called_once_with('*')
assert client.namespace == 'n' assert client.namespace == 'n'
@ -67,7 +68,8 @@ class TestAsyncAsyncSimpleClient(unittest.TestCase):
def test_properties(self): def test_properties(self):
client = AsyncSimpleClient() client = AsyncSimpleClient()
client.client = mock.MagicMock(sid='sid', transport='websocket') client.client = mock.MagicMock(transport='websocket')
client.client.get_sid.return_value = 'sid'
client.connected_event.set() client.connected_event.set()
client.connected = True client.connected = True

57
tests/asyncio_web_server.py

@ -0,0 +1,57 @@
import requests
import threading
import time
import uvicorn
import socketio
class SocketIOWebServer:
"""A simple web server used for running Socket.IO servers in tests.
:param sio: a Socket.IO server instance.
Note 1: This class is not production-ready and is intended for testing.
Note 2: This class only supports the "asgi" async_mode.
"""
def __init__(self, sio, on_shutdown=None):
if sio.async_mode != 'asgi':
raise ValueError('The async_mode must be "asgi"')
async def http_app(scope, receive, send):
await send({'type': 'http.response.start',
'status': 200,
'headers': [('Content-Type', 'text/plain')]})
await send({'type': 'http.response.body',
'body': b'OK'})
self.sio = sio
self.app = socketio.ASGIApp(sio, http_app, on_shutdown=on_shutdown)
self.httpd = None
self.thread = None
def start(self, port=8900):
"""Start the web server.
:param port: the port to listen on. Defaults to 8900.
The server is started in a background thread.
"""
self.httpd = uvicorn.Server(config=uvicorn.Config(self.app, port=port))
self.thread = threading.Thread(target=self.httpd.run)
self.thread.start()
# wait for the server to start
while True:
try:
r = requests.get(f'http://localhost:{port}/')
r.raise_for_status()
if r.text == 'OK':
break
except:
time.sleep(0.1)
def stop(self):
"""Stop the web server."""
self.httpd.should_exit = True
self.thread.join()
self.thread = None

277
tests/common/test_admin.py

@ -0,0 +1,277 @@
from functools import wraps
import threading
import time
from unittest import mock
import unittest
import pytest
from engineio.socket import Socket as EngineIOSocket
import socketio
from socketio.exceptions import ConnectionError
from tests.web_server import SocketIOWebServer
def with_instrumented_server(auth=False, **ikwargs):
"""This decorator can be applied to test functions or methods so that they
run with a Socket.IO server that has been instrumented for the official
Admin UI project. The arguments passed to the decorator are passed directly
to the ``instrument()`` method of the server.
"""
def decorator(f):
@wraps(f)
def wrapped(self, *args, **kwargs):
sio = socketio.Server(async_mode='threading')
instrumented_server = sio.instrument(auth=auth, **ikwargs)
@sio.event
def enter_room(sid, data):
sio.enter_room(sid, data)
@sio.event
def emit(sid, event):
sio.emit(event, skip_sid=sid)
@sio.event(namespace='/foo')
def connect(sid, environ, auth):
pass
server = SocketIOWebServer(sio)
server.start()
# import logging
# logging.getLogger('engineio.client').setLevel(logging.DEBUG)
# logging.getLogger('socketio.client').setLevel(logging.DEBUG)
original_schedule_ping = EngineIOSocket.schedule_ping
EngineIOSocket.schedule_ping = mock.MagicMock()
try:
ret = f(self, instrumented_server, *args, **kwargs)
finally:
server.stop()
instrumented_server.shutdown()
instrumented_server.uninstrument()
EngineIOSocket.schedule_ping = original_schedule_ping
# import logging
# logging.getLogger('engineio.client').setLevel(logging.NOTSET)
# logging.getLogger('socketio.client').setLevel(logging.NOTSET)
return ret
return wrapped
return decorator
def _custom_auth(auth):
return auth == {'foo': 'bar'}
class TestAdmin(unittest.TestCase):
def setUp(self):
print('threads at start:', threading.enumerate())
self.thread_count = threading.active_count()
def tearDown(self):
print('threads at end:', threading.enumerate())
assert self.thread_count == threading.active_count()
def test_missing_auth(self):
sio = socketio.Server(async_mode='threading')
with pytest.raises(ValueError):
sio.instrument()
@with_instrumented_server(auth=False)
def test_admin_connect_with_no_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin')
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
@with_instrumented_server(auth={'foo': 'bar'})
def test_admin_connect_with_dict_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect(
'http://localhost:8900', namespace='/admin',
auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect(
'http://localhost:8900', namespace='/admin')
@with_instrumented_server(auth=[{'foo': 'bar'},
{'u': 'admin', 'p': 'secret'}])
def test_admin_connect_with_list_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'u': 'admin', 'p': 'secret'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin', auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin')
@with_instrumented_server(auth=_custom_auth)
def test_admin_connect_with_function_auth(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin',
auth={'foo': 'bar'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin', auth={'foo': 'baz'})
with socketio.SimpleClient() as admin_client:
with pytest.raises(ConnectionError):
admin_client.connect('http://localhost:8900',
namespace='/admin')
@with_instrumented_server()
def test_admin_connect_only_admin(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin')
sid = admin_client.sid
expected = ['config', 'all_sockets', 'server_stats']
events = {}
while expected:
data = admin_client.receive(timeout=5)
if data[0] in expected:
events[data[0]] = data[1]
expected.remove(data[0])
assert 'supportedFeatures' in events['config']
assert 'ALL_EVENTS' in events['config']['supportedFeatures']
assert len(events['all_sockets']) == 1
assert events['all_sockets'][0]['id'] == sid
assert events['all_sockets'][0]['rooms'] == [sid]
assert events['server_stats']['clientsCount'] == 1
assert events['server_stats']['pollingClientsCount'] == 0
assert len(events['server_stats']['namespaces']) == 3
assert {'name': '/', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/foo', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/admin', 'socketsCount': 1} in \
events['server_stats']['namespaces']
@with_instrumented_server()
def test_admin_connect_with_others(self, isvr):
with socketio.SimpleClient() as client1, \
socketio.SimpleClient() as client2, \
socketio.SimpleClient() as client3, \
socketio.SimpleClient() as admin_client:
client1.connect('http://localhost:8900')
client1.emit('enter_room', 'room')
sid1 = client1.sid
saved_check_for_upgrade = isvr._check_for_upgrade
isvr._check_for_upgrade = mock.MagicMock()
client2.connect('http://localhost:8900', namespace='/foo',
transports=['polling'])
sid2 = client2.sid
isvr._check_for_upgrade = saved_check_for_upgrade
client3.connect('http://localhost:8900', namespace='/admin')
sid3 = client3.sid
admin_client.connect('http://localhost:8900', namespace='/admin')
sid = admin_client.sid
expected = ['config', 'all_sockets', 'server_stats']
events = {}
while expected:
data = admin_client.receive(timeout=5)
if data[0] in expected:
events[data[0]] = data[1]
expected.remove(data[0])
assert 'supportedFeatures' in events['config']
assert 'ALL_EVENTS' in events['config']['supportedFeatures']
assert len(events['all_sockets']) == 4
assert events['server_stats']['clientsCount'] == 4
assert events['server_stats']['pollingClientsCount'] == 1
assert len(events['server_stats']['namespaces']) == 3
assert {'name': '/', 'socketsCount': 1} in \
events['server_stats']['namespaces']
assert {'name': '/foo', 'socketsCount': 1} in \
events['server_stats']['namespaces']
assert {'name': '/admin', 'socketsCount': 2} in \
events['server_stats']['namespaces']
for socket in events['all_sockets']:
if socket['id'] == sid:
assert socket['rooms'] == [sid]
elif socket['id'] == sid1:
assert socket['rooms'] == [sid1, 'room']
elif socket['id'] == sid2:
assert socket['rooms'] == [sid2]
elif socket['id'] == sid3:
assert socket['rooms'] == [sid3]
@with_instrumented_server(mode='production')
def test_admin_connect_production(self, isvr):
with socketio.SimpleClient() as admin_client:
admin_client.connect('http://localhost:8900', namespace='/admin')
expected = ['config', 'server_stats']
events = {}
while expected:
data = admin_client.receive(timeout=5)
if data[0] in expected:
events[data[0]] = data[1]
expected.remove(data[0])
assert 'supportedFeatures' in events['config']
assert 'ALL_EVENTS' not in events['config']['supportedFeatures']
assert events['server_stats']['clientsCount'] == 1
assert events['server_stats']['pollingClientsCount'] == 0
assert len(events['server_stats']['namespaces']) == 3
assert {'name': '/', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/foo', 'socketsCount': 0} in \
events['server_stats']['namespaces']
assert {'name': '/admin', 'socketsCount': 1} in \
events['server_stats']['namespaces']
@with_instrumented_server()
def test_admin_features(self, isvr):
with socketio.SimpleClient() as client1, \
socketio.SimpleClient() as client2, \
socketio.SimpleClient() as admin_client:
client1.connect('http://localhost:8900')
client2.connect('http://localhost:8900')
admin_client.connect('http://localhost:8900', namespace='/admin')
# emit from admin
admin_client.emit(
'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra'))
data = client1.receive(timeout=5)
assert data == ['foo', {'bar': 'baz'}, 'extra']
# emit from regular client
client1.emit('emit', 'foo')
data = client2.receive(timeout=5)
assert data == ['foo']
# join and leave
admin_client.emit('join', ('/', 'room', client1.sid))
admin_client.emit(
'emit', ('/', 'room', 'foo', {'bar': 'baz'}))
data = client1.receive(timeout=5)
assert data == ['foo', {'bar': 'baz'}]
admin_client.emit('leave', ('/', 'room'))
# disconnect
admin_client.emit('_disconnect', ('/', False, client1.sid))
for _ in range(10):
if not client1.connected:
break
time.sleep(0.2)
assert not client1.connected

12
tests/common/test_simple_client.py

@ -18,12 +18,12 @@ class TestSimpleClient(unittest.TestCase):
client = SimpleClient(123, a='b') client = SimpleClient(123, a='b')
with mock.patch('socketio.simple_client.Client') as mock_client: with mock.patch('socketio.simple_client.Client') as mock_client:
client.connect('url', headers='h', auth='a', transports='t', client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s') namespace='n', socketio_path='s', wait_timeout='w')
mock_client.assert_called_once_with(123, a='b') mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client() assert client.client == mock_client()
mock_client().connect.assert_called_once_with( mock_client().connect.assert_called_once_with(
'url', headers='h', auth='a', transports='t', 'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s') namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3 mock_client().event.call_count == 3
mock_client().on.called_once_with('*') mock_client().on.called_once_with('*')
assert client.namespace == 'n' assert client.namespace == 'n'
@ -33,12 +33,13 @@ class TestSimpleClient(unittest.TestCase):
with SimpleClient(123, a='b') as client: with SimpleClient(123, a='b') as client:
with mock.patch('socketio.simple_client.Client') as mock_client: with mock.patch('socketio.simple_client.Client') as mock_client:
client.connect('url', headers='h', auth='a', transports='t', client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s') namespace='n', socketio_path='s',
wait_timeout='w')
mock_client.assert_called_once_with(123, a='b') mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client() assert client.client == mock_client()
mock_client().connect.assert_called_once_with( mock_client().connect.assert_called_once_with(
'url', headers='h', auth='a', transports='t', 'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s') namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3 mock_client().event.call_count == 3
mock_client().on.called_once_with('*') mock_client().on.called_once_with('*')
assert client.namespace == 'n' assert client.namespace == 'n'
@ -54,7 +55,8 @@ class TestSimpleClient(unittest.TestCase):
def test_properties(self): def test_properties(self):
client = SimpleClient() client = SimpleClient()
client.client = mock.MagicMock(sid='sid', transport='websocket') client.client = mock.MagicMock(transport='websocket')
client.client.get_sid.return_value = 'sid'
client.connected_event.set() client.connected_event.set()
client.connected = True client.connected = True

81
tests/web_server.py

@ -0,0 +1,81 @@
import threading
import time
from socketserver import ThreadingMixIn
from wsgiref.simple_server import make_server, WSGIServer, WSGIRequestHandler
import requests
import socketio
class SocketIOWebServer:
"""A simple web server used for running Socket.IO servers in tests.
:param sio: a Socket.IO server instance.
Note 1: This class is not production-ready and is intended for testing.
Note 2: This class only supports the "threading" async_mode, with WebSocket
support provided by the simple-websocket package.
"""
def __init__(self, sio):
if sio.async_mode != 'threading':
raise ValueError('The async_mode must be "threading"')
def http_app(environ, start_response):
start_response('200 OK', [('Content-Type', 'text/plain')])
return [b'OK']
self.sio = sio
self.app = socketio.WSGIApp(sio, http_app)
self.httpd = None
self.thread = None
def start(self, port=8900):
"""Start the web server.
:param port: the port to listen on. Defaults to 8900.
The server is started in a background thread.
"""
class ThreadingWSGIServer(ThreadingMixIn, WSGIServer):
pass
class WebSocketRequestHandler(WSGIRequestHandler):
def get_environ(self):
env = super().get_environ()
# pass the raw socket to the WSGI app so that it can be used
# by WebSocket connections (hack copied from gunicorn)
env['gunicorn.socket'] = self.connection
return env
self.httpd = make_server('', port, self._app_wrapper,
ThreadingWSGIServer, WebSocketRequestHandler)
self.thread = threading.Thread(target=self.httpd.serve_forever)
self.thread.start()
# wait for the server to start
while True:
try:
r = requests.get(f'http://localhost:{port}/')
r.raise_for_status()
if r.text == 'OK':
break
except:
time.sleep(0.1)
def stop(self):
"""Stop the web server."""
self.sio.shutdown()
self.httpd.shutdown()
self.httpd.server_close()
self.thread.join()
self.httpd = None
self.thread = None
def _app_wrapper(self, environ, start_response):
try:
return self.app(environ, start_response)
except StopIteration:
# end the WebSocket request without sending a response
# (this is a hack that was copied from gunicorn's threaded worker)
start_response('200 OK', [])
return []

8
tox.ini

@ -14,10 +14,16 @@ python =
[testenv] [testenv]
commands= commands=
pip install -e . pip install -e .
pytest -p no:logging --cov=socketio --cov-branch --cov-report=term-missing --cov-report=xml pytest -p no:logging --timeout=60 --cov=socketio --cov-branch --cov-report=term-missing --cov-report=xml
deps= deps=
simple-websocket
uvicorn
requests
websocket-client
aiohttp
msgpack msgpack
pytest pytest
pytest-timeout
pytest-cov pytest-cov
[testenv:flake8] [testenv:flake8]

Loading…
Cancel
Save