From bd8555da8523d1a73432685a00eb5acb4d2261f5 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 18 Dec 2024 17:39:03 +0000 Subject: [PATCH] Pass a `reason` argument to the disconnect handler (#1422) --- docs/client.rst | 23 +++-- docs/server.rst | 23 ++++- examples/client/async/fiddle_client.py | 4 +- examples/client/sync/fiddle_client.py | 4 +- examples/server/aiohttp/app.html | 4 +- examples/server/aiohttp/app.py | 6 +- examples/server/aiohttp/fiddle.py | 6 +- examples/server/asgi/app.html | 4 +- examples/server/asgi/app.py | 4 +- examples/server/asgi/fiddle.py | 4 +- examples/server/javascript/fiddle.js | 4 +- examples/server/sanic/app.html | 4 +- examples/server/sanic/app.py | 4 +- examples/server/sanic/fiddle.py | 4 +- examples/server/tornado/app.py | 4 +- examples/server/tornado/fiddle.py | 4 +- examples/server/tornado/templates/app.html | 4 +- examples/server/wsgi/app.py | 4 +- .../socketio_app/static/index.html | 4 +- .../django_socketio/socketio_app/views.py | 4 +- examples/server/wsgi/fiddle.py | 4 +- examples/server/wsgi/templates/index.html | 4 +- pyproject.toml | 2 +- src/socketio/async_client.py | 33 +++++-- src/socketio/async_namespace.py | 40 ++++++++- src/socketio/async_server.py | 33 +++++-- src/socketio/base_client.py | 5 +- src/socketio/base_server.py | 3 + src/socketio/client.py | 20 +++-- src/socketio/namespace.py | 18 +++- src/socketio/server.py | 24 +++-- tests/async/test_client.py | 89 ++++++++++--------- tests/async/test_manager.py | 2 - tests/async/test_namespace.py | 62 ++++++++++++- tests/async/test_server.py | 53 ++++++++--- tests/common/test_client.py | 72 +++++++-------- tests/common/test_namespace.py | 38 +++++++- tests/common/test_server.py | 38 +++++--- 38 files changed, 469 insertions(+), 193 deletions(-) diff --git a/docs/client.rst b/docs/client.rst index 1a55b71..e3e1fb2 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -312,8 +312,8 @@ server:: print("The connection failed!") @sio.event - def disconnect(): - print("I'm disconnected!") + def disconnect(reason): + print("I'm disconnected! reason:", reason) The ``connect_error`` handler is invoked when a connection attempt fails. If the server provides arguments, these are passed on to the handler. The server @@ -325,7 +325,20 @@ server initiated disconnects, or accidental disconnects, for example due to networking failures. In the case of an accidental disconnection, the client is going to attempt to reconnect immediately after invoking the disconnect handler. As soon as the connection is re-established the connect handler will -be invoked once again. +be invoked once again. The handler receives a ``reason`` argument which +provides the cause of the disconnection:: + + @sio.event + def disconnect(reason): + if reason == sio.reason.CLIENT_DISCONNECT: + print('the client disconnected') + elif reason == sio.reason.SERVER_DISCONNECT: + print('the server disconnected the client') + else: + print('disconnect reason:', reason) + +See the The :attr:`socketio.Client.reason` attribute for a list of possible +disconnection reasons. The ``connect``, ``connect_error`` and ``disconnect`` events have to be defined explicitly and are not invoked on a catch-all event handler. @@ -509,7 +522,7 @@ that belong to a namespace can be created as methods of a subclass of def on_connect(self): pass - def on_disconnect(self): + def on_disconnect(self, reason): pass def on_my_event(self, data): @@ -525,7 +538,7 @@ coroutines if desired:: def on_connect(self): pass - def on_disconnect(self): + def on_disconnect(self, reason): pass async def on_my_event(self, data): diff --git a/docs/server.rst b/docs/server.rst index c20adf9..ed15ed3 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -232,8 +232,8 @@ automatically when a client connects or disconnects from the server:: print('connect ', sid) @sio.event - def disconnect(sid): - print('disconnect ', sid) + def disconnect(sid, reason): + print('disconnect ', sid, reason) The ``connect`` event is an ideal place to perform user authentication, and any necessary mapping between user entities in the application and the ``sid`` @@ -256,6 +256,21 @@ message:: def connect(sid, environ, auth): raise ConnectionRefusedError('authentication failed') +The disconnect handler receives the ``sid`` assigned to the client and a +``reason``, which provides the cause of the disconnection:: + + @sio.event + def disconnect(sid, reason): + if reason == sio.reason.CLIENT_DISCONNECT: + print('the client disconnected') + elif reason == sio.reason.SERVER_DISCONNECT: + print('the server disconnected the client') + else: + print('disconnect reason:', reason) + +See the The :attr:`socketio.Server.reason` attribute for a list of possible +disconnection reasons. + Catch-All Event Handlers ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -433,7 +448,7 @@ belong to a namespace can be created as methods in a subclass of def on_connect(self, sid, environ): pass - def on_disconnect(self, sid): + def on_disconnect(self, sid, reason): pass def on_my_event(self, sid, data): @@ -449,7 +464,7 @@ if desired:: def on_connect(self, sid, environ): pass - def on_disconnect(self, sid): + def on_disconnect(self, sid, reason): pass async def on_my_event(self, sid, data): diff --git a/examples/client/async/fiddle_client.py b/examples/client/async/fiddle_client.py index 5b43dcc..e5aeb6c 100644 --- a/examples/client/async/fiddle_client.py +++ b/examples/client/async/fiddle_client.py @@ -10,8 +10,8 @@ async def connect(): @sio.event -async def disconnect(): - print('disconnected from server') +async def disconnect(reason): + print('disconnected from server, reason:', reason) @sio.event diff --git a/examples/client/sync/fiddle_client.py b/examples/client/sync/fiddle_client.py index 50f5e2a..71a7a54 100644 --- a/examples/client/sync/fiddle_client.py +++ b/examples/client/sync/fiddle_client.py @@ -9,8 +9,8 @@ def connect(): @sio.event -def disconnect(): - print('disconnected from server') +def disconnect(reason): + print('disconnected from server, reason:', reason) @sio.event diff --git a/examples/server/aiohttp/app.html b/examples/server/aiohttp/app.html index 74d404d..627b918 100644 --- a/examples/server/aiohttp/app.html +++ b/examples/server/aiohttp/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/aiohttp/app.py b/examples/server/aiohttp/app.py index cba5193..1568ca1 100644 --- a/examples/server/aiohttp/app.py +++ b/examples/server/aiohttp/app.py @@ -70,8 +70,8 @@ async def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) app.router.add_static('/static', 'static') @@ -84,4 +84,4 @@ async def init_app(): if __name__ == '__main__': - web.run_app(init_app()) + web.run_app(init_app(), port=5000) diff --git a/examples/server/aiohttp/fiddle.py b/examples/server/aiohttp/fiddle.py index dfde8e1..64ce330 100644 --- a/examples/server/aiohttp/fiddle.py +++ b/examples/server/aiohttp/fiddle.py @@ -19,8 +19,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) app.router.add_static('/static', 'static') @@ -28,4 +28,4 @@ app.router.add_get('/', index) if __name__ == '__main__': - web.run_app(app) + web.run_app(app, port=5000) diff --git a/examples/server/asgi/app.html b/examples/server/asgi/app.html index d2f0e9a..ad82656 100644 --- a/examples/server/asgi/app.html +++ b/examples/server/asgi/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/asgi/app.py b/examples/server/asgi/app.py index 36af85f..d549ab0 100644 --- a/examples/server/asgi/app.py +++ b/examples/server/asgi/app.py @@ -88,8 +88,8 @@ async def test_connect(sid, environ): @sio.on('disconnect') -def test_disconnect(sid): - print('Client disconnected') +def test_disconnect(sid, reason): + print('Client disconnected, reason:', reason) if __name__ == '__main__': diff --git a/examples/server/asgi/fiddle.py b/examples/server/asgi/fiddle.py index 6899ed1..402a379 100644 --- a/examples/server/asgi/fiddle.py +++ b/examples/server/asgi/fiddle.py @@ -17,8 +17,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) if __name__ == '__main__': diff --git a/examples/server/javascript/fiddle.js b/examples/server/javascript/fiddle.js index 940e4da..c6a039a 100644 --- a/examples/server/javascript/fiddle.js +++ b/examples/server/javascript/fiddle.js @@ -19,8 +19,8 @@ io.on('connection', socket => { hello: 'you' }); - socket.on('disconnect', () => { - console.log(`disconnect ${socket.id}`); + socket.on('disconnect', (reason) => { + console.log(`disconnect ${socket.id}, reason: ${reason}`); }); }); diff --git a/examples/server/sanic/app.html b/examples/server/sanic/app.html index 30c5964..b87b2df 100644 --- a/examples/server/sanic/app.html +++ b/examples/server/sanic/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/sanic/app.py b/examples/server/sanic/app.py index 7f02d23..447ddff 100644 --- a/examples/server/sanic/app.py +++ b/examples/server/sanic/app.py @@ -77,8 +77,8 @@ async def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) app.static('/static', './static') diff --git a/examples/server/sanic/fiddle.py b/examples/server/sanic/fiddle.py index 5ecb509..405e6e5 100644 --- a/examples/server/sanic/fiddle.py +++ b/examples/server/sanic/fiddle.py @@ -21,8 +21,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) app.static('/static', './static') diff --git a/examples/server/tornado/app.py b/examples/server/tornado/app.py index 16f7a19..58317d9 100644 --- a/examples/server/tornado/app.py +++ b/examples/server/tornado/app.py @@ -75,8 +75,8 @@ async def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) def main(): diff --git a/examples/server/tornado/fiddle.py b/examples/server/tornado/fiddle.py index 1e7e927..b3878a2 100644 --- a/examples/server/tornado/fiddle.py +++ b/examples/server/tornado/fiddle.py @@ -24,8 +24,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) def main(): diff --git a/examples/server/tornado/templates/app.html b/examples/server/tornado/templates/app.html index 74d404d..627b918 100644 --- a/examples/server/tornado/templates/app.html +++ b/examples/server/tornado/templates/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/wsgi/app.py b/examples/server/wsgi/app.py index 7b019fd..62bd59b 100644 --- a/examples/server/wsgi/app.py +++ b/examples/server/wsgi/app.py @@ -94,8 +94,8 @@ def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) if __name__ == '__main__': diff --git a/examples/server/wsgi/django_socketio/socketio_app/static/index.html b/examples/server/wsgi/django_socketio/socketio_app/static/index.html index 6dbef78..b10818f 100644 --- a/examples/server/wsgi/django_socketio/socketio_app/static/index.html +++ b/examples/server/wsgi/django_socketio/socketio_app/static/index.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/wsgi/django_socketio/socketio_app/views.py b/examples/server/wsgi/django_socketio/socketio_app/views.py index 854c0fb..f54e1d6 100644 --- a/examples/server/wsgi/django_socketio/socketio_app/views.py +++ b/examples/server/wsgi/django_socketio/socketio_app/views.py @@ -78,5 +78,5 @@ def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) diff --git a/examples/server/wsgi/fiddle.py b/examples/server/wsgi/fiddle.py index 247751b..e9cd703 100644 --- a/examples/server/wsgi/fiddle.py +++ b/examples/server/wsgi/fiddle.py @@ -23,8 +23,8 @@ def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) if __name__ == '__main__': diff --git a/examples/server/wsgi/templates/index.html b/examples/server/wsgi/templates/index.html index 8a7308a..e37a6cb 100644 --- a/examples/server/wsgi/templates/index.html +++ b/examples/server/wsgi/templates/index.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/pyproject.toml b/pyproject.toml index 8a5453a..a3ebb6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ requires-python = ">=3.8" dependencies = [ "bidict >= 0.21.0", - "python-engineio >= 4.8.0", + "python-engineio >= 4.11.0", ] [project.readme] diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index fb1abc1..463073e 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -338,7 +338,6 @@ class AsyncClient(base_client.BaseClient): await self.disconnect() elif self._reconnect_task: # pragma: no branch self._reconnect_abort.set() - print(self._reconnect_task) await self._reconnect_task def start_background_task(self, target, *args, **kwargs): @@ -398,8 +397,9 @@ class AsyncClient(base_client.BaseClient): if not self.connected: return namespace = namespace or '/' - await self._trigger_event('disconnect', namespace=namespace) - await self._trigger_event('__disconnect_final', namespace=namespace) + await self._trigger_event('disconnect', namespace, + self.reason.SERVER_DISCONNECT) + await self._trigger_event('__disconnect_final', namespace) if namespace in self.namespaces: del self.namespaces[namespace] if not self.namespaces: @@ -462,11 +462,27 @@ class AsyncClient(base_client.BaseClient): if handler: if asyncio.iscoroutinefunction(handler): try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret # or else, forward the event to a namepsace handler if one exists @@ -566,16 +582,15 @@ class AsyncClient(base_client.BaseClient): else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self): + async def _handle_eio_disconnect(self, reason): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') will_reconnect = self.reconnection and self.eio.state == 'connected' if self.connected: for n in self.namespaces: - await self._trigger_event('disconnect', namespace=n) + await self._trigger_event('disconnect', n, reason) if not will_reconnect: - await self._trigger_event('__disconnect_final', - namespace=n) + await self._trigger_event('__disconnect_final', n) self.namespaces = {} self.connected = False self.callbacks = {} diff --git a/src/socketio/async_namespace.py b/src/socketio/async_namespace.py index 89442ae..42d6508 100644 --- a/src/socketio/async_namespace.py +++ b/src/socketio/async_namespace.py @@ -34,11 +34,27 @@ class AsyncNamespace(base_namespace.BaseServerNamespace): handler = getattr(self, handler_name) if asyncio.iscoroutinefunction(handler) is True: try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret async def emit(self, event, data=None, to=None, room=None, skip_sid=None, @@ -199,11 +215,27 @@ class AsyncClientNamespace(base_namespace.BaseClientNamespace): handler = getattr(self, handler_name) if asyncio.iscoroutinefunction(handler) is True: try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret async def emit(self, event, data=None, namespace=None, callback=None): diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 9b0e977..f10fb8a 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -427,7 +427,8 @@ class AsyncServer(base_server.BaseServer): eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) - await self._trigger_event('disconnect', namespace, sid) + await self._trigger_event('disconnect', namespace, sid, + self.reason.SERVER_DISCONNECT) await self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) @@ -575,14 +576,15 @@ class AsyncServer(base_server.BaseServer): await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) - async def _handle_disconnect(self, eio_sid, namespace): + async def _handle_disconnect(self, eio_sid, namespace, reason=None): """Handle a client disconnect.""" namespace = namespace or '/' sid = self.manager.sid_from_eio_sid(eio_sid, namespace) if not self.manager.is_connected(sid, namespace): # pragma: no cover return self.manager.pre_disconnect(sid, namespace=namespace) - await self._trigger_event('disconnect', namespace, sid) + await self._trigger_event('disconnect', namespace, sid, + reason or self.reason.CLIENT_DISCONNECT) await self.manager.disconnect(sid, namespace, ignore_queue=True) async def _handle_event(self, eio_sid, namespace, id, data): @@ -634,11 +636,25 @@ class AsyncServer(base_server.BaseServer): if handler: if asyncio.iscoroutinefunction(handler): try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) @@ -671,7 +687,8 @@ class AsyncServer(base_server.BaseServer): if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - await self._handle_disconnect(eio_sid, pkt.namespace) + await self._handle_disconnect(eio_sid, pkt.namespace, + self.reason.CLIENT_DISCONNECT) elif pkt.packet_type == packet.EVENT: await self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) @@ -686,10 +703,10 @@ class AsyncServer(base_server.BaseServer): else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self, eio_sid): + async def _handle_eio_disconnect(self, eio_sid, reason): """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): - await self._handle_disconnect(eio_sid, n) + await self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 1becf91..7bf4420 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -3,6 +3,8 @@ import logging import signal import threading +import engineio + from . import base_namespace from . import packet @@ -31,6 +33,7 @@ original_signal_handler = None class BaseClient: reserved_events = ['connect', 'connect_error', 'disconnect', '__disconnect_final'] + reason = engineio.Client.reason def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, @@ -285,7 +288,7 @@ class BaseClient: def _handle_eio_message(self, data): # pragma: no cover raise NotImplementedError() - def _handle_eio_disconnect(self): # pragma: no cover + def _handle_eio_disconnect(self, reason): # pragma: no cover raise NotImplementedError() def _engineio_client_class(self): # pragma: no cover diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index d5a353b..d134eba 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -1,5 +1,7 @@ import logging +import engineio + from . import manager from . import base_namespace from . import packet @@ -9,6 +11,7 @@ default_logger = logging.getLogger('socketio.server') class BaseServer: reserved_events = ['connect', 'disconnect'] + reason = engineio.Server.reason def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, diff --git a/src/socketio/client.py b/src/socketio/client.py index c4f9eaa..ade2dd6 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -377,8 +377,9 @@ class Client(base_client.BaseClient): if not self.connected: return namespace = namespace or '/' - self._trigger_event('disconnect', namespace=namespace) - self._trigger_event('__disconnect_final', namespace=namespace) + self._trigger_event('disconnect', namespace, + self.reason.SERVER_DISCONNECT) + self._trigger_event('__disconnect_final', namespace) if namespace in self.namespaces: del self.namespaces[namespace] if not self.namespaces: @@ -436,7 +437,14 @@ class Client(base_client.BaseClient): # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) if handler: - return handler(*args) + try: + return handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason argument + if event == 'disconnect': + return handler(*args[:-1]) + else: # pragma: no cover + raise # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) @@ -525,15 +533,15 @@ class Client(base_client.BaseClient): else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self): + def _handle_eio_disconnect(self, reason): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') will_reconnect = self.reconnection and self.eio.state == 'connected' if self.connected: for n in self.namespaces: - self._trigger_event('disconnect', namespace=n) + self._trigger_event('disconnect', n, reason) if not will_reconnect: - self._trigger_event('__disconnect_final', namespace=n) + self._trigger_event('__disconnect_final', n) self.namespaces = {} self.connected = False self.callbacks = {} diff --git a/src/socketio/namespace.py b/src/socketio/namespace.py index 3bf4f95..60cab78 100644 --- a/src/socketio/namespace.py +++ b/src/socketio/namespace.py @@ -23,7 +23,14 @@ class Namespace(base_namespace.BaseServerNamespace): """ handler_name = 'on_' + (event or '') if hasattr(self, handler_name): - return getattr(self, handler_name)(*args) + try: + return getattr(self, handler_name)(*args) + except TypeError: + # legacy disconnect events do not have a reason argument + if event == 'disconnect': + return getattr(self, handler_name)(*args[:-1]) + else: # pragma: no cover + raise def emit(self, event, data=None, to=None, room=None, skip_sid=None, namespace=None, callback=None, ignore_queue=False): @@ -154,7 +161,14 @@ class ClientNamespace(base_namespace.BaseClientNamespace): """ handler_name = 'on_' + (event or '') if hasattr(self, handler_name): - return getattr(self, handler_name)(*args) + try: + return getattr(self, handler_name)(*args) + except TypeError: + # legacy disconnect events do not have a reason argument + if event == 'disconnect': + return getattr(self, handler_name)(*args[:-1]) + else: # pragma: no cover + raise def emit(self, event, data=None, namespace=None, callback=None): """Emit a custom event to the server. diff --git a/src/socketio/server.py b/src/socketio/server.py index ae73df6..71c702d 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -403,7 +403,8 @@ class Server(base_server.BaseServer): eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) - self._trigger_event('disconnect', namespace, sid) + self._trigger_event('disconnect', namespace, sid, + self.reason.SERVER_DISCONNECT) self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) @@ -557,14 +558,15 @@ class Server(base_server.BaseServer): self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) - def _handle_disconnect(self, eio_sid, namespace): + def _handle_disconnect(self, eio_sid, namespace, reason=None): """Handle a client disconnect.""" namespace = namespace or '/' sid = self.manager.sid_from_eio_sid(eio_sid, namespace) if not self.manager.is_connected(sid, namespace): # pragma: no cover return self.manager.pre_disconnect(sid, namespace=namespace) - self._trigger_event('disconnect', namespace, sid) + self._trigger_event('disconnect', namespace, sid, + reason or self.reason.CLIENT_DISCONNECT) self.manager.disconnect(sid, namespace, ignore_queue=True) def _handle_event(self, eio_sid, namespace, id, data): @@ -611,7 +613,14 @@ class Server(base_server.BaseServer): # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) if handler: - return handler(*args) + try: + return handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + return handler(*args[:-1]) + else: # pragma: no cover + raise # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) if handler: @@ -642,7 +651,8 @@ class Server(base_server.BaseServer): if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - self._handle_disconnect(eio_sid, pkt.namespace) + self._handle_disconnect(eio_sid, pkt.namespace, + self.reason.CLIENT_DISCONNECT) elif pkt.packet_type == packet.EVENT: self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.ACK: @@ -655,10 +665,10 @@ class Server(base_server.BaseServer): else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self, eio_sid): + def _handle_eio_disconnect(self, eio_sid, reason): """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): - self._handle_disconnect(eio_sid, n) + self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 26681b7..3eec1a8 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -578,11 +578,9 @@ class TestAsyncClient: c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/') c._trigger_event.assert_any_await( - 'disconnect', namespace='/' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/' + 'disconnect', '/', c.reason.SERVER_DISCONNECT ) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert not c.connected await c._handle_disconnect('/') assert c._trigger_event.await_count == 2 @@ -593,21 +591,15 @@ class TestAsyncClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/foo') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/foo' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/foo' - ) + c._trigger_event.assert_any_await('disconnect', '/foo', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/foo') assert c.namespaces == {'/bar': '2'} assert c.connected await c._handle_disconnect('/bar') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/bar' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/bar' - ) + c._trigger_event.assert_any_await('disconnect', '/bar', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/bar') assert c.namespaces == {} assert not c.connected @@ -617,12 +609,9 @@ class TestAsyncClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/baz') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/baz' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/baz' - ) + c._trigger_event.assert_any_await('disconnect', '/baz', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/baz') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -632,8 +621,9 @@ class TestAsyncClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/') - c._trigger_event.assert_any_await('disconnect', namespace='/') - c._trigger_event.assert_any_await('__disconnect_final', namespace='/') + c._trigger_event.assert_any_await('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -818,6 +808,26 @@ class TestAsyncClient: handler.assert_awaited_once_with(1, '2') catchall_handler.assert_awaited_once_with('bar', 1, '2', 3) + async def test_trigger_legacy_disconnect_event(self): + c = async_client.AsyncClient() + + @c.on('disconnect') + def baz(): + return 'baz' + + r = await c._trigger_event('disconnect', '/', 'foo') + assert r == 'baz' + + async def test_trigger_legacy_disconnect_event_async(self): + c = async_client.AsyncClient() + + @c.on('disconnect') + async def baz(): + return 'baz' + + r = await c._trigger_event('disconnect', '/', 'foo') + assert r == 'baz' + async def test_trigger_event_class_namespace(self): c = async_client.AsyncClient() result = [] @@ -1127,10 +1137,8 @@ class TestAsyncClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_awaited_once_with( - 'disconnect', namespace='/' - ) + await c._handle_eio_disconnect('foo') + c._trigger_event.assert_awaited_once_with('disconnect', '/', 'foo') assert c.sid is None assert not c.connected @@ -1141,9 +1149,13 @@ class TestAsyncClient: c._trigger_event = mock.AsyncMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_any_await('disconnect', namespace='/foo') - c._trigger_event.assert_any_await('disconnect', namespace='/bar') + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_await('disconnect', '/foo', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_await('disconnect', '/bar', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.asserT_any_await('disconnect', '/', + c.reason.CLIENT_DISCONNECT) assert c.sid is None assert not c.connected @@ -1151,14 +1163,14 @@ class TestAsyncClient: c = async_client.AsyncClient(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'connected' - await c._handle_eio_disconnect() + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_called_once_with(c._handle_reconnect) async def test_eio_disconnect_self_disconnect(self): c = async_client.AsyncClient(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'disconnected' - await c._handle_eio_disconnect() + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_not_called() async def test_eio_disconnect_no_reconnect(self): @@ -1169,13 +1181,10 @@ class TestAsyncClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_any_await( - 'disconnect', namespace='/' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/' - ) + await c._handle_eio_disconnect(c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_await('disconnect', '/', + c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() diff --git a/tests/async/test_manager.py b/tests/async/test_manager.py index fd5fe81..aa89064 100644 --- a/tests/async/test_manager.py +++ b/tests/async/test_manager.py @@ -183,8 +183,6 @@ class TestAsyncManager: await self.bm.enter_room(sid, '/foo', 'bar') await self.bm.enter_room(sid, '/foo', 'bar') await self.bm.close_room('bar', '/foo') - from pprint import pprint - pprint(self.bm.rooms) assert 'bar' not in self.bm.rooms['/foo'] async def test_close_invalid_room(self): diff --git a/tests/async/test_namespace.py b/tests/async/test_namespace.py index ad9b1a0..526d676 100644 --- a/tests/async/test_namespace.py +++ b/tests/async/test_namespace.py @@ -19,13 +19,37 @@ class TestAsyncNamespace: async def test_disconnect_event(self): result = {} + class MyNamespace(async_namespace.AsyncNamespace): + async def on_disconnect(self, sid, reason): + result['result'] = (sid, reason) + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + await ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == ('sid', 'foo') + + async def test_legacy_disconnect_event(self): + result = {} + + class MyNamespace(async_namespace.AsyncNamespace): + def on_disconnect(self, sid): + result['result'] = sid + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + await ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == 'sid' + + async def test_legacy_disconnect_event_async(self): + result = {} + class MyNamespace(async_namespace.AsyncNamespace): async def on_disconnect(self, sid): result['result'] = sid ns = MyNamespace('/foo') ns._set_server(mock.MagicMock()) - await ns.trigger_event('disconnect', 'sid') + await ns.trigger_event('disconnect', 'sid', 'foo') assert result['result'] == 'sid' async def test_sync_event(self): @@ -242,6 +266,42 @@ class TestAsyncNamespace: await ns.disconnect('sid', namespace='/bar') ns.server.disconnect.assert_awaited_with('sid', namespace='/bar') + async def test_disconnect_event_client(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + async def on_disconnect(self, reason): + result['result'] = reason + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'foo' + + async def test_legacy_disconnect_event_client(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + + async def test_legacy_disconnect_event_client_async(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + async def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + async def test_sync_event_client(self): result = {} diff --git a/tests/async/test_server.py b/tests/async/test_server.py index d9129d4..f60de27 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -56,7 +56,7 @@ class TestAsyncServer: def foo(): pass - def bar(): + def bar(reason): pass s.on('disconnect', bar) @@ -537,8 +537,36 @@ class TestAsyncServer: s.on('disconnect', handler) await s._handle_eio_connect('123', 'environ') await s._handle_eio_message('123', '0') - await s._handle_eio_disconnect('123') - handler.assert_called_once_with('1') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_called_once_with('1', 'foo') + s.manager.disconnect.assert_awaited_once_with( + '1', '/', ignore_queue=True) + assert s.environ == {} + + async def test_handle_legacy_disconnect(self, eio): + eio.return_value.send = mock.AsyncMock() + s = async_server.AsyncServer() + s.manager.disconnect = mock.AsyncMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + await s._handle_eio_connect('123', 'environ') + await s._handle_eio_message('123', '0') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_called_with('1') + s.manager.disconnect.assert_awaited_once_with( + '1', '/', ignore_queue=True) + assert s.environ == {} + + async def test_handle_legacy_disconnect_async(self, eio): + eio.return_value.send = mock.AsyncMock() + s = async_server.AsyncServer() + s.manager.disconnect = mock.AsyncMock() + handler = mock.AsyncMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + await s._handle_eio_connect('123', 'environ') + await s._handle_eio_message('123', '0') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_awaited_with('1') s.manager.disconnect.assert_awaited_once_with( '1', '/', ignore_queue=True) assert s.environ == {} @@ -552,9 +580,9 @@ class TestAsyncServer: s.on('disconnect', handler_namespace, namespace='/foo') await s._handle_eio_connect('123', 'environ') await s._handle_eio_message('123', '0/foo,') - await s._handle_eio_disconnect('123') + await s._handle_eio_disconnect('123', 'foo') handler.assert_not_called() - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with('1', 'foo') assert s.environ == {} async def test_handle_disconnect_only_namespace(self, eio): @@ -568,13 +596,14 @@ class TestAsyncServer: await s._handle_eio_message('123', '0/foo,') await s._handle_eio_message('123', '1/foo,') assert handler.call_count == 0 - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with( + '1', s.reason.CLIENT_DISCONNECT) assert s.environ == {'123': 'environ'} async def test_handle_disconnect_unknown_client(self, eio): mgr = self._get_mock_manager() s = async_server.AsyncServer(client_manager=mgr) - await s._handle_eio_disconnect('123') + await s._handle_eio_disconnect('123', 'foo') async def test_handle_event(self, eio): eio.return_value.send = mock.AsyncMock() @@ -624,7 +653,8 @@ class TestAsyncServer: await s._handle_eio_message('123', '2/bar,["msg","a","b"]') await s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') await s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') - await s._trigger_event('disconnect', '/bar', sid_bar) + await s._trigger_event('disconnect', '/bar', sid_bar, + s.reason.CLIENT_DISCONNECT) connect_star_handler.assert_called_once_with('/bar', sid_bar) msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') @@ -884,8 +914,8 @@ class TestAsyncServer: def on_connect(self, sid, environ): result['result'] = (sid, environ) - async def on_disconnect(self, sid): - result['result'] = ('disconnect', sid) + async def on_disconnect(self, sid, reason): + result['result'] = ('disconnect', sid, reason) async def on_foo(self, sid, data): result['result'] = (sid, data) @@ -908,7 +938,8 @@ class TestAsyncServer: await s._handle_eio_message('123', '2/foo,["baz","a","b"]') assert result['result'] == ('a', 'b') await s.disconnect('1', '/foo') - assert result['result'] == ('disconnect', '1') + assert result['result'] == ('disconnect', '1', + s.reason.SERVER_DISCONNECT) async def test_catchall_namespace_handler(self, eio): eio.return_value.send = mock.AsyncMock() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index c7399dc..ac930c7 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -752,8 +752,9 @@ class TestClient: c.connected = True c._trigger_event = mock.MagicMock() c._handle_disconnect('/') - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert not c.connected c._handle_disconnect('/') assert c._trigger_event.call_count == 2 @@ -764,21 +765,15 @@ class TestClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/foo') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/foo' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/foo' - ) + c._trigger_event.assert_any_call('disconnect', '/foo', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/foo') assert c.namespaces == {'/bar': '2'} assert c.connected c._handle_disconnect('/bar') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/bar' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/bar' - ) + c._trigger_event.assert_any_call('disconnect', '/bar', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/bar') assert c.namespaces == {} assert not c.connected @@ -788,12 +783,9 @@ class TestClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/baz') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/baz' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/baz' - ) + c._trigger_event.assert_any_call('disconnect', '/baz', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/baz') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -803,9 +795,9 @@ class TestClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/') - print(c._trigger_event.call_args_list) - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -1003,8 +995,8 @@ class TestClient: def on_connect(self, ns): result['result'] = (ns,) - def on_disconnect(self, ns): - result['result'] = ('disconnect', ns) + def on_disconnect(self, ns, reason): + result['result'] = ('disconnect', ns, reason) def on_foo(self, ns, data): result['result'] = (ns, data) @@ -1025,8 +1017,8 @@ class TestClient: assert result['result'] == 'bar/foo' c._trigger_event('baz', '/foo', 'a', 'b') assert result['result'] == ('/foo', 'a', 'b') - c._trigger_event('disconnect', '/foo') - assert result['result'] == ('disconnect', '/foo') + c._trigger_event('disconnect', '/foo', 'bar') + assert result['result'] == ('disconnect', '/foo', 'bar') def test_trigger_event_class_namespace(self): c = client.Client() @@ -1286,8 +1278,8 @@ class TestClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_called_once_with('disconnect', namespace='/') + c._handle_eio_disconnect('foo') + c._trigger_event.assert_called_once_with('disconnect', '/', 'foo') assert c.sid is None assert not c.connected @@ -1299,10 +1291,13 @@ class TestClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_any_call('disconnect', namespace='/foo') - c._trigger_event.assert_any_call('disconnect', namespace='/bar') - c._trigger_event.assert_any_call('disconnect', namespace='/') + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/foo', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/bar', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.CLIENT_DISCONNECT) assert c.sid is None assert not c.connected @@ -1310,14 +1305,14 @@ class TestClient: c = client.Client(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'connected' - c._handle_eio_disconnect() + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_called_once_with(c._handle_reconnect) def test_eio_disconnect_self_disconnect(self): c = client.Client(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'disconnected' - c._handle_eio_disconnect() + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_not_called() def test_eio_disconnect_no_reconnect(self): @@ -1328,9 +1323,10 @@ class TestClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._handle_eio_disconnect(c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() diff --git a/tests/common/test_namespace.py b/tests/common/test_namespace.py index 8bfa989..f1476e4 100644 --- a/tests/common/test_namespace.py +++ b/tests/common/test_namespace.py @@ -19,13 +19,25 @@ class TestNamespace: def test_disconnect_event(self): result = {} + class MyNamespace(namespace.Namespace): + def on_disconnect(self, sid, reason): + result['result'] = (sid, reason) + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == ('sid', 'foo') + + def test_legacy_disconnect_event(self): + result = {} + class MyNamespace(namespace.Namespace): def on_disconnect(self, sid): result['result'] = sid ns = MyNamespace('/foo') ns._set_server(mock.MagicMock()) - ns.trigger_event('disconnect', 'sid') + ns.trigger_event('disconnect', 'sid', 'foo') assert result['result'] == 'sid' def test_event(self): @@ -216,6 +228,30 @@ class TestNamespace: ns.disconnect('sid', namespace='/bar') ns.server.disconnect.assert_called_with('sid', namespace='/bar') + def test_disconnect_event_client(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_disconnect(self, reason): + result['result'] = reason + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'foo' + + def test_legacy_disconnect_event_client(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + def test_event_not_found_client(self): result = {} diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 57ddc2f..445d5d9 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -39,7 +39,7 @@ class TestServer: def foo(): pass - def bar(): + def bar(reason): pass s.on('disconnect', bar) @@ -510,8 +510,21 @@ class TestServer: s.on('disconnect', handler) s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') - s._handle_eio_disconnect('123') - handler.assert_called_once_with('1') + s._handle_eio_disconnect('123', 'foo') + handler.assert_called_once_with('1', 'foo') + s.manager.disconnect.assert_called_once_with('1', '/', + ignore_queue=True) + assert s.environ == {} + + def test_handle_legacy_disconnect(self, eio): + s = server.Server() + s.manager.disconnect = mock.MagicMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + s._handle_eio_disconnect('123', 'foo') + handler.assert_called_with('1') s.manager.disconnect.assert_called_once_with('1', '/', ignore_queue=True) assert s.environ == {} @@ -524,9 +537,9 @@ class TestServer: s.on('disconnect', handler_namespace, namespace='/foo') s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo,') - s._handle_eio_disconnect('123') + s._handle_eio_disconnect('123', 'foo') handler.assert_not_called() - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with('1', 'foo') assert s.environ == {} def test_handle_disconnect_only_namespace(self, eio): @@ -539,13 +552,14 @@ class TestServer: s._handle_eio_message('123', '0/foo,') s._handle_eio_message('123', '1/foo,') assert handler.call_count == 0 - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with( + '1', s.reason.CLIENT_DISCONNECT) assert s.environ == {'123': 'environ'} def test_handle_disconnect_unknown_client(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr) - s._handle_eio_disconnect('123') + s._handle_eio_disconnect('123', 'foo') def test_handle_event(self, eio): s = server.Server(async_handlers=False) @@ -596,7 +610,8 @@ class TestServer: s._handle_eio_message('123', '2/bar,["msg","a","b"]') s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') - s._trigger_event('disconnect', '/bar', sid_bar) + s._trigger_event('disconnect', '/bar', sid_bar, + s.reason.CLIENT_DISCONNECT) connect_star_handler.assert_called_once_with('/bar', sid_bar) msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') @@ -825,8 +840,8 @@ class TestServer: def on_connect(self, sid, environ): result['result'] = (sid, environ) - def on_disconnect(self, sid): - result['result'] = ('disconnect', sid) + def on_disconnect(self, sid, reason): + result['result'] = ('disconnect', sid, reason) def on_foo(self, sid, data): result['result'] = (sid, data) @@ -849,7 +864,8 @@ class TestServer: s._handle_eio_message('123', '2/foo,["baz","a","b"]') assert result['result'] == ('a', 'b') s.disconnect('1', '/foo') - assert result['result'] == ('disconnect', '1') + assert result['result'] == ('disconnect', '1', + s.reason.SERVER_DISCONNECT) def test_catchall_namespace_handler(self, eio): result = {}