From a6e6ca6265c6bdf37feba35ca4a293b9f5efef40 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 15 Dec 2024 20:24:28 +0000 Subject: [PATCH] disconnect reasons for client --- src/socketio/async_client.py | 32 +++++++++--- src/socketio/async_namespace.py | 30 +++++++++-- src/socketio/base_client.py | 2 +- src/socketio/client.py | 20 +++++--- src/socketio/namespace.py | 9 +++- tests/async/test_client.py | 89 ++++++++++++++++++--------------- tests/async/test_namespace.py | 62 ++++++++++++++++++++++- tests/common/test_client.py | 71 +++++++++++++------------- tests/common/test_namespace.py | 39 ++++++++++++++- 9 files changed, 256 insertions(+), 98 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index fb1abc1..c3a42bd 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -398,8 +398,9 @@ class AsyncClient(base_client.BaseClient): if not self.connected: return namespace = namespace or '/' - await self._trigger_event('disconnect', namespace=namespace) - await self._trigger_event('__disconnect_final', namespace=namespace) + await self._trigger_event('disconnect', namespace, + self.reason.SERVER_DISCONNECT) + await self._trigger_event('__disconnect_final', namespace) if namespace in self.namespaces: del self.namespaces[namespace] if not self.namespaces: @@ -462,11 +463,27 @@ class AsyncClient(base_client.BaseClient): if handler: if asyncio.iscoroutinefunction(handler): try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret # or else, forward the event to a namepsace handler if one exists @@ -566,16 +583,15 @@ class AsyncClient(base_client.BaseClient): else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self): + async def _handle_eio_disconnect(self, reason): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') will_reconnect = self.reconnection and self.eio.state == 'connected' if self.connected: for n in self.namespaces: - await self._trigger_event('disconnect', namespace=n) + await self._trigger_event('disconnect', n, reason) if not will_reconnect: - await self._trigger_event('__disconnect_final', - namespace=n) + await self._trigger_event('__disconnect_final', n) self.namespaces = {} self.connected = False self.callbacks = {} diff --git a/src/socketio/async_namespace.py b/src/socketio/async_namespace.py index 48676e5..42d6508 100644 --- a/src/socketio/async_namespace.py +++ b/src/socketio/async_namespace.py @@ -46,7 +46,15 @@ class AsyncNamespace(base_namespace.BaseServerNamespace): except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret async def emit(self, event, data=None, to=None, room=None, skip_sid=None, @@ -207,11 +215,27 @@ class AsyncClientNamespace(base_namespace.BaseClientNamespace): handler = getattr(self, handler_name) if asyncio.iscoroutinefunction(handler) is True: try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret async def emit(self, event, data=None, namespace=None, callback=None): diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index d620a3b..ef89c11 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -289,7 +289,7 @@ class BaseClient: def _handle_eio_message(self, data): # pragma: no cover raise NotImplementedError() - def _handle_eio_disconnect(self): # pragma: no cover + def _handle_eio_disconnect(self, reason): # pragma: no cover raise NotImplementedError() def _engineio_client_class(self): # pragma: no cover diff --git a/src/socketio/client.py b/src/socketio/client.py index c4f9eaa..ade2dd6 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -377,8 +377,9 @@ class Client(base_client.BaseClient): if not self.connected: return namespace = namespace or '/' - self._trigger_event('disconnect', namespace=namespace) - self._trigger_event('__disconnect_final', namespace=namespace) + self._trigger_event('disconnect', namespace, + self.reason.SERVER_DISCONNECT) + self._trigger_event('__disconnect_final', namespace) if namespace in self.namespaces: del self.namespaces[namespace] if not self.namespaces: @@ -436,7 +437,14 @@ class Client(base_client.BaseClient): # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) if handler: - return handler(*args) + try: + return handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason argument + if event == 'disconnect': + return handler(*args[:-1]) + else: # pragma: no cover + raise # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) @@ -525,15 +533,15 @@ class Client(base_client.BaseClient): else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self): + def _handle_eio_disconnect(self, reason): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') will_reconnect = self.reconnection and self.eio.state == 'connected' if self.connected: for n in self.namespaces: - self._trigger_event('disconnect', namespace=n) + self._trigger_event('disconnect', n, reason) if not will_reconnect: - self._trigger_event('__disconnect_final', namespace=n) + self._trigger_event('__disconnect_final', n) self.namespaces = {} self.connected = False self.callbacks = {} diff --git a/src/socketio/namespace.py b/src/socketio/namespace.py index 25b32b9..60cab78 100644 --- a/src/socketio/namespace.py +++ b/src/socketio/namespace.py @@ -161,7 +161,14 @@ class ClientNamespace(base_namespace.BaseClientNamespace): """ handler_name = 'on_' + (event or '') if hasattr(self, handler_name): - return getattr(self, handler_name)(*args) + try: + return getattr(self, handler_name)(*args) + except TypeError: + # legacy disconnect events do not have a reason argument + if event == 'disconnect': + return getattr(self, handler_name)(*args[:-1]) + else: # pragma: no cover + raise def emit(self, event, data=None, namespace=None, callback=None): """Emit a custom event to the server. diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 26681b7..3eec1a8 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -578,11 +578,9 @@ class TestAsyncClient: c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/') c._trigger_event.assert_any_await( - 'disconnect', namespace='/' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/' + 'disconnect', '/', c.reason.SERVER_DISCONNECT ) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert not c.connected await c._handle_disconnect('/') assert c._trigger_event.await_count == 2 @@ -593,21 +591,15 @@ class TestAsyncClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/foo') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/foo' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/foo' - ) + c._trigger_event.assert_any_await('disconnect', '/foo', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/foo') assert c.namespaces == {'/bar': '2'} assert c.connected await c._handle_disconnect('/bar') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/bar' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/bar' - ) + c._trigger_event.assert_any_await('disconnect', '/bar', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/bar') assert c.namespaces == {} assert not c.connected @@ -617,12 +609,9 @@ class TestAsyncClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/baz') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/baz' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/baz' - ) + c._trigger_event.assert_any_await('disconnect', '/baz', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/baz') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -632,8 +621,9 @@ class TestAsyncClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/') - c._trigger_event.assert_any_await('disconnect', namespace='/') - c._trigger_event.assert_any_await('__disconnect_final', namespace='/') + c._trigger_event.assert_any_await('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -818,6 +808,26 @@ class TestAsyncClient: handler.assert_awaited_once_with(1, '2') catchall_handler.assert_awaited_once_with('bar', 1, '2', 3) + async def test_trigger_legacy_disconnect_event(self): + c = async_client.AsyncClient() + + @c.on('disconnect') + def baz(): + return 'baz' + + r = await c._trigger_event('disconnect', '/', 'foo') + assert r == 'baz' + + async def test_trigger_legacy_disconnect_event_async(self): + c = async_client.AsyncClient() + + @c.on('disconnect') + async def baz(): + return 'baz' + + r = await c._trigger_event('disconnect', '/', 'foo') + assert r == 'baz' + async def test_trigger_event_class_namespace(self): c = async_client.AsyncClient() result = [] @@ -1127,10 +1137,8 @@ class TestAsyncClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_awaited_once_with( - 'disconnect', namespace='/' - ) + await c._handle_eio_disconnect('foo') + c._trigger_event.assert_awaited_once_with('disconnect', '/', 'foo') assert c.sid is None assert not c.connected @@ -1141,9 +1149,13 @@ class TestAsyncClient: c._trigger_event = mock.AsyncMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_any_await('disconnect', namespace='/foo') - c._trigger_event.assert_any_await('disconnect', namespace='/bar') + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_await('disconnect', '/foo', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_await('disconnect', '/bar', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.asserT_any_await('disconnect', '/', + c.reason.CLIENT_DISCONNECT) assert c.sid is None assert not c.connected @@ -1151,14 +1163,14 @@ class TestAsyncClient: c = async_client.AsyncClient(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'connected' - await c._handle_eio_disconnect() + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_called_once_with(c._handle_reconnect) async def test_eio_disconnect_self_disconnect(self): c = async_client.AsyncClient(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'disconnected' - await c._handle_eio_disconnect() + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_not_called() async def test_eio_disconnect_no_reconnect(self): @@ -1169,13 +1181,10 @@ class TestAsyncClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_any_await( - 'disconnect', namespace='/' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/' - ) + await c._handle_eio_disconnect(c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_await('disconnect', '/', + c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() diff --git a/tests/async/test_namespace.py b/tests/async/test_namespace.py index ad9b1a0..656cb47 100644 --- a/tests/async/test_namespace.py +++ b/tests/async/test_namespace.py @@ -19,13 +19,37 @@ class TestAsyncNamespace: async def test_disconnect_event(self): result = {} + class MyNamespace(async_namespace.AsyncNamespace): + async def on_disconnect(self, sid, reason): + result['result'] = (sid, reason) + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + _run(ns.trigger_event('disconnect', 'sid', 'foo')) + assert result['result'] == ('sid', 'foo') + + def test_legacy_disconnect_event(self): + result = {} + + class MyNamespace(async_namespace.AsyncNamespace): + def on_disconnect(self, sid): + result['result'] = sid + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + _run(ns.trigger_event('disconnect', 'sid', 'foo')) + assert result['result'] == 'sid' + + def test_legacy_disconnect_event_async(self): + result = {} + class MyNamespace(async_namespace.AsyncNamespace): async def on_disconnect(self, sid): result['result'] = sid ns = MyNamespace('/foo') ns._set_server(mock.MagicMock()) - await ns.trigger_event('disconnect', 'sid') + await ns.trigger_event('disconnect', 'sid', 'foo') assert result['result'] == 'sid' async def test_sync_event(self): @@ -242,6 +266,42 @@ class TestAsyncNamespace: await ns.disconnect('sid', namespace='/bar') ns.server.disconnect.assert_awaited_with('sid', namespace='/bar') + async def test_disconnect_event(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + async def on_disconnect(self, reason): + result['result'] = reason + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'foo' + + async def test_legacy_disconnect_event_client(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + + async def test_legacy_disconnect_event_client_async(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + async def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + async def test_sync_event_client(self): result = {} diff --git a/tests/common/test_client.py b/tests/common/test_client.py index c7399dc..1ebcea7 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -752,8 +752,9 @@ class TestClient: c.connected = True c._trigger_event = mock.MagicMock() c._handle_disconnect('/') - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert not c.connected c._handle_disconnect('/') assert c._trigger_event.call_count == 2 @@ -764,21 +765,15 @@ class TestClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/foo') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/foo' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/foo' - ) + c._trigger_event.assert_any_call('disconnect', '/foo', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/foo') assert c.namespaces == {'/bar': '2'} assert c.connected c._handle_disconnect('/bar') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/bar' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/bar' - ) + c._trigger_event.assert_any_call('disconnect', '/bar', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/bar') assert c.namespaces == {} assert not c.connected @@ -788,12 +783,9 @@ class TestClient: c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/baz') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/baz' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/baz' - ) + c._trigger_event.assert_any_call('disconnect', '/baz', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/baz') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -804,8 +796,9 @@ class TestClient: c._trigger_event = mock.MagicMock() c._handle_disconnect('/') print(c._trigger_event.call_args_list) - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -1003,8 +996,8 @@ class TestClient: def on_connect(self, ns): result['result'] = (ns,) - def on_disconnect(self, ns): - result['result'] = ('disconnect', ns) + def on_disconnect(self, ns, reason): + result['result'] = ('disconnect', ns, reason) def on_foo(self, ns, data): result['result'] = (ns, data) @@ -1025,8 +1018,8 @@ class TestClient: assert result['result'] == 'bar/foo' c._trigger_event('baz', '/foo', 'a', 'b') assert result['result'] == ('/foo', 'a', 'b') - c._trigger_event('disconnect', '/foo') - assert result['result'] == ('disconnect', '/foo') + c._trigger_event('disconnect', '/foo', 'bar') + assert result['result'] == ('disconnect', '/foo', 'bar') def test_trigger_event_class_namespace(self): c = client.Client() @@ -1286,8 +1279,8 @@ class TestClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_called_once_with('disconnect', namespace='/') + c._handle_eio_disconnect('foo') + c._trigger_event.assert_called_once_with('disconnect', '/', 'foo') assert c.sid is None assert not c.connected @@ -1299,10 +1292,13 @@ class TestClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_any_call('disconnect', namespace='/foo') - c._trigger_event.assert_any_call('disconnect', namespace='/bar') - c._trigger_event.assert_any_call('disconnect', namespace='/') + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/foo', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/bar', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.CLIENT_DISCONNECT) assert c.sid is None assert not c.connected @@ -1310,14 +1306,14 @@ class TestClient: c = client.Client(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'connected' - c._handle_eio_disconnect() + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_called_once_with(c._handle_reconnect) def test_eio_disconnect_self_disconnect(self): c = client.Client(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'disconnected' - c._handle_eio_disconnect() + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_not_called() def test_eio_disconnect_no_reconnect(self): @@ -1328,9 +1324,10 @@ class TestClient: c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._handle_eio_disconnect(c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() diff --git a/tests/common/test_namespace.py b/tests/common/test_namespace.py index 8bfa989..be6bccb 100644 --- a/tests/common/test_namespace.py +++ b/tests/common/test_namespace.py @@ -19,13 +19,25 @@ class TestNamespace: def test_disconnect_event(self): result = {} + class MyNamespace(namespace.Namespace): + def on_disconnect(self, sid, reason): + result['result'] = (sid, reason) + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == ('sid', 'foo') + + def test_legacy_disconnect_event(self): + result = {} + class MyNamespace(namespace.Namespace): def on_disconnect(self, sid): result['result'] = sid ns = MyNamespace('/foo') ns._set_server(mock.MagicMock()) - ns.trigger_event('disconnect', 'sid') + ns.trigger_event('disconnect', 'sid', 'foo') assert result['result'] == 'sid' def test_event(self): @@ -216,6 +228,31 @@ class TestNamespace: ns.disconnect('sid', namespace='/bar') ns.server.disconnect.assert_called_with('sid', namespace='/bar') + def test_disconnect_event_client(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_disconnect(self, reason): + result['result'] = reason + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'foo' + + def test_legacy_disconnect_event_client(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + + def test_event_not_found_client(self): result = {}