From db691a939772f45aede52000a190804826c6fea6 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 15 Dec 2024 18:55:57 +0000 Subject: [PATCH] disconnect reasons for server --- src/socketio/async_namespace.py | 10 ++++++- src/socketio/async_server.py | 33 ++++++++++++++++----- src/socketio/base_client.py | 4 +++ src/socketio/base_server.py | 3 ++ src/socketio/namespace.py | 9 +++++- src/socketio/server.py | 24 ++++++++++----- tests/async/test_server.py | 52 ++++++++++++++++++++++++++------- tests/common/test_server.py | 37 ++++++++++++++++------- 8 files changed, 133 insertions(+), 39 deletions(-) diff --git a/src/socketio/async_namespace.py b/src/socketio/async_namespace.py index 89442ae..48676e5 100644 --- a/src/socketio/async_namespace.py +++ b/src/socketio/async_namespace.py @@ -34,7 +34,15 @@ class AsyncNamespace(base_namespace.BaseServerNamespace): handler = getattr(self, handler_name) if asyncio.iscoroutinefunction(handler) is True: try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 9b0e977..5f51844 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -427,7 +427,8 @@ class AsyncServer(base_server.BaseServer): eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) - await self._trigger_event('disconnect', namespace, sid) + await self._trigger_event('disconnect', namespace, sid, + self.reason.SERVER_DISCONNECT) await self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) @@ -575,14 +576,15 @@ class AsyncServer(base_server.BaseServer): await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) - async def _handle_disconnect(self, eio_sid, namespace): + async def _handle_disconnect(self, eio_sid, namespace, reason=None): """Handle a client disconnect.""" namespace = namespace or '/' sid = self.manager.sid_from_eio_sid(eio_sid, namespace) if not self.manager.is_connected(sid, namespace): # pragma: no cover return self.manager.pre_disconnect(sid, namespace=namespace) - await self._trigger_event('disconnect', namespace, sid) + await self._trigger_event('disconnect', namespace, sid, + reason or self.reason.UNKNOWN) await self.manager.disconnect(sid, namespace, ignore_queue=True) async def _handle_event(self, eio_sid, namespace, id, data): @@ -634,11 +636,25 @@ class AsyncServer(base_server.BaseServer): if handler: if asyncio.iscoroutinefunction(handler): try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) @@ -671,7 +687,8 @@ class AsyncServer(base_server.BaseServer): if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - await self._handle_disconnect(eio_sid, pkt.namespace) + await self._handle_disconnect(eio_sid, pkt.namespace, + self.reason.CLIENT_DISCONNECT) elif pkt.packet_type == packet.EVENT: await self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) @@ -686,10 +703,10 @@ class AsyncServer(base_server.BaseServer): else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self, eio_sid): + async def _handle_eio_disconnect(self, eio_sid, reason): """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): - await self._handle_disconnect(eio_sid, n) + await self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 1becf91..d620a3b 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -3,6 +3,8 @@ import logging import signal import threading +import engineio + from . import base_namespace from . import packet @@ -31,6 +33,8 @@ original_signal_handler = None class BaseClient: reserved_events = ['connect', 'connect_error', 'disconnect', '__disconnect_final'] + print(dir(engineio.Client)) + reason = engineio.Client.reason def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index d5a353b..d134eba 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -1,5 +1,7 @@ import logging +import engineio + from . import manager from . import base_namespace from . import packet @@ -9,6 +11,7 @@ default_logger = logging.getLogger('socketio.server') class BaseServer: reserved_events = ['connect', 'disconnect'] + reason = engineio.Server.reason def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, diff --git a/src/socketio/namespace.py b/src/socketio/namespace.py index 3bf4f95..25b32b9 100644 --- a/src/socketio/namespace.py +++ b/src/socketio/namespace.py @@ -23,7 +23,14 @@ class Namespace(base_namespace.BaseServerNamespace): """ handler_name = 'on_' + (event or '') if hasattr(self, handler_name): - return getattr(self, handler_name)(*args) + try: + return getattr(self, handler_name)(*args) + except TypeError: + # legacy disconnect events do not have a reason argument + if event == 'disconnect': + return getattr(self, handler_name)(*args[:-1]) + else: # pragma: no cover + raise def emit(self, event, data=None, to=None, room=None, skip_sid=None, namespace=None, callback=None, ignore_queue=False): diff --git a/src/socketio/server.py b/src/socketio/server.py index ae73df6..21a6de4 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -403,7 +403,8 @@ class Server(base_server.BaseServer): eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) - self._trigger_event('disconnect', namespace, sid) + self._trigger_event('disconnect', namespace, sid, + self.reason.SERVER_DISCONNECT) self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) @@ -557,14 +558,15 @@ class Server(base_server.BaseServer): self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) - def _handle_disconnect(self, eio_sid, namespace): + def _handle_disconnect(self, eio_sid, namespace, reason=None): """Handle a client disconnect.""" namespace = namespace or '/' sid = self.manager.sid_from_eio_sid(eio_sid, namespace) if not self.manager.is_connected(sid, namespace): # pragma: no cover return self.manager.pre_disconnect(sid, namespace=namespace) - self._trigger_event('disconnect', namespace, sid) + self._trigger_event('disconnect', namespace, sid, + reason or self.reason.UNKNOWN) self.manager.disconnect(sid, namespace, ignore_queue=True) def _handle_event(self, eio_sid, namespace, id, data): @@ -611,7 +613,14 @@ class Server(base_server.BaseServer): # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) if handler: - return handler(*args) + try: + return handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + return handler(*args[:-1]) + else: # pragma: no cover + raise # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) if handler: @@ -642,7 +651,8 @@ class Server(base_server.BaseServer): if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - self._handle_disconnect(eio_sid, pkt.namespace) + self._handle_disconnect(eio_sid, pkt.namespace, + self.reason.CLIENT_DISCONNECT) elif pkt.packet_type == packet.EVENT: self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.ACK: @@ -655,10 +665,10 @@ class Server(base_server.BaseServer): else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self, eio_sid): + def _handle_eio_disconnect(self, eio_sid, reason): """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): - self._handle_disconnect(eio_sid, n) + self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] diff --git a/tests/async/test_server.py b/tests/async/test_server.py index d9129d4..34b6c38 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -56,7 +56,7 @@ class TestAsyncServer: def foo(): pass - def bar(): + def bar(reason): pass s.on('disconnect', bar) @@ -537,8 +537,36 @@ class TestAsyncServer: s.on('disconnect', handler) await s._handle_eio_connect('123', 'environ') await s._handle_eio_message('123', '0') - await s._handle_eio_disconnect('123') - handler.assert_called_once_with('1') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_called_once_with('1', 'foo') + s.manager.disconnect.assert_awaited_once_with( + '1', '/', ignore_queue=True) + assert s.environ == {} + + async def test_handle_legacy_disconnect(self, eio): + eio.return_value.send = mock.AsyncMock() + s = async_server.AsyncServer() + s.manager.disconnect = mock.AsyncMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + await s._handle_eio_connect('123', 'environ') + await s._handle_eio_message('123', '0') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_called_with('1') + s.manager.disconnect.assert_awaited_once_with( + '1', '/', ignore_queue=True) + assert s.environ == {} + + async def test_handle_legacy_disconnect_async(self, eio): + eio.return_value.send = mock.AsyncMock() + s = async_server.AsyncServer() + s.manager.disconnect = mock.AsyncMock() + handler = mock.AsyncMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + await s._handle_eio_connect('123', 'environ') + await s._handle_eio_message('123', '0') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_awaited_with('1') s.manager.disconnect.assert_awaited_once_with( '1', '/', ignore_queue=True) assert s.environ == {} @@ -552,9 +580,9 @@ class TestAsyncServer: s.on('disconnect', handler_namespace, namespace='/foo') await s._handle_eio_connect('123', 'environ') await s._handle_eio_message('123', '0/foo,') - await s._handle_eio_disconnect('123') + await s._handle_eio_disconnect('123', 'foo') handler.assert_not_called() - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with('1', 'foo') assert s.environ == {} async def test_handle_disconnect_only_namespace(self, eio): @@ -568,13 +596,14 @@ class TestAsyncServer: await s._handle_eio_message('123', '0/foo,') await s._handle_eio_message('123', '1/foo,') assert handler.call_count == 0 - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with( + '1', s.reason.CLIENT_DISCONNECT) assert s.environ == {'123': 'environ'} async def test_handle_disconnect_unknown_client(self, eio): mgr = self._get_mock_manager() s = async_server.AsyncServer(client_manager=mgr) - await s._handle_eio_disconnect('123') + await s._handle_eio_disconnect('123', 'foo') async def test_handle_event(self, eio): eio.return_value.send = mock.AsyncMock() @@ -624,7 +653,7 @@ class TestAsyncServer: await s._handle_eio_message('123', '2/bar,["msg","a","b"]') await s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') await s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') - await s._trigger_event('disconnect', '/bar', sid_bar) + await s._trigger_event('disconnect', '/bar', sid_bar, s.reason.UNKNOWN) connect_star_handler.assert_called_once_with('/bar', sid_bar) msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') @@ -884,8 +913,8 @@ class TestAsyncServer: def on_connect(self, sid, environ): result['result'] = (sid, environ) - async def on_disconnect(self, sid): - result['result'] = ('disconnect', sid) + async def on_disconnect(self, sid, reason): + result['result'] = ('disconnect', sid, reason) async def on_foo(self, sid, data): result['result'] = (sid, data) @@ -908,7 +937,8 @@ class TestAsyncServer: await s._handle_eio_message('123', '2/foo,["baz","a","b"]') assert result['result'] == ('a', 'b') await s.disconnect('1', '/foo') - assert result['result'] == ('disconnect', '1') + assert result['result'] == ('disconnect', '1', + s.reason.SERVER_DISCONNECT) async def test_catchall_namespace_handler(self, eio): eio.return_value.send = mock.AsyncMock() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 57ddc2f..5ba119a 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -39,7 +39,7 @@ class TestServer: def foo(): pass - def bar(): + def bar(reason): pass s.on('disconnect', bar) @@ -510,8 +510,21 @@ class TestServer: s.on('disconnect', handler) s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') - s._handle_eio_disconnect('123') - handler.assert_called_once_with('1') + s._handle_eio_disconnect('123', 'foo') + handler.assert_called_once_with('1', 'foo') + s.manager.disconnect.assert_called_once_with('1', '/', + ignore_queue=True) + assert s.environ == {} + + def test_handle_legacy_disconnect(self, eio): + s = server.Server() + s.manager.disconnect = mock.MagicMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + s._handle_eio_disconnect('123', 'foo') + handler.assert_called_with('1') s.manager.disconnect.assert_called_once_with('1', '/', ignore_queue=True) assert s.environ == {} @@ -524,9 +537,9 @@ class TestServer: s.on('disconnect', handler_namespace, namespace='/foo') s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo,') - s._handle_eio_disconnect('123') + s._handle_eio_disconnect('123', 'foo') handler.assert_not_called() - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with('1', 'foo') assert s.environ == {} def test_handle_disconnect_only_namespace(self, eio): @@ -539,13 +552,14 @@ class TestServer: s._handle_eio_message('123', '0/foo,') s._handle_eio_message('123', '1/foo,') assert handler.call_count == 0 - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with( + '1', s.reason.CLIENT_DISCONNECT) assert s.environ == {'123': 'environ'} def test_handle_disconnect_unknown_client(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr) - s._handle_eio_disconnect('123') + s._handle_eio_disconnect('123', 'foo') def test_handle_event(self, eio): s = server.Server(async_handlers=False) @@ -596,7 +610,7 @@ class TestServer: s._handle_eio_message('123', '2/bar,["msg","a","b"]') s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') - s._trigger_event('disconnect', '/bar', sid_bar) + s._trigger_event('disconnect', '/bar', sid_bar, s.reason.UNKNOWN) connect_star_handler.assert_called_once_with('/bar', sid_bar) msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') @@ -825,8 +839,8 @@ class TestServer: def on_connect(self, sid, environ): result['result'] = (sid, environ) - def on_disconnect(self, sid): - result['result'] = ('disconnect', sid) + def on_disconnect(self, sid, reason): + result['result'] = ('disconnect', sid, reason) def on_foo(self, sid, data): result['result'] = (sid, data) @@ -849,7 +863,8 @@ class TestServer: s._handle_eio_message('123', '2/foo,["baz","a","b"]') assert result['result'] == ('a', 'b') s.disconnect('1', '/foo') - assert result['result'] == ('disconnect', '1') + assert result['result'] == ('disconnect', '1', + s.reason.SERVER_DISCONNECT) def test_catchall_namespace_handler(self, eio): result = {}