diff --git a/docs/client.rst b/docs/client.rst index aea9aba..5350bcb 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -253,25 +253,49 @@ or can also be coroutines:: If the server includes arguments with an event, those are passed to the handler function as arguments. -Catch-All Event Handlers -~~~~~~~~~~~~~~~~~~~~~~~~ +Catch-All Event and Namespace Handlers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ A "catch-all" event handler is invoked for any events that do not have an event handler. You can define a catch-all handler using ``'*'`` as event name:: @sio.on('*') - def catch_all(event, data): - pass + def any_event(event, sid, data): + pass -Asyncio clients can also use a coroutine:: +Asyncio servers can also use a coroutine:: @sio.on('*') - async def catch_all(event, data): - pass + async def any_event(event, sid, data): + pass A catch-all event handler receives the event name as a first argument. The remaining arguments are the same as for a regular event handler. +The ``connect`` and ``disconnect`` events have to be defined explicitly and are +not invoked on a catch-all event handler. + +Similarily, a "catch-all" namespace handler is invoked for any connected +namespaces that do not have an explicitly defined event handler. As with +catch-all events, ``'*'`` is used in place of a namespace:: + + @sio.on('my_event', namespace='*') + def my_event_any_namespace(namespace, sid, data): + pass + +For these events, the namespace is passed as first argument, followed by the +regular arguments of the event. + +Lastly, it is also possible to define a "catch-all" handler for all events on +all namespaces:: + + @sio.on('*', namespace='*') + def any_event_any_namespace(event, namespace, sid, data): + pass + +Event handlers with catch-all events and namespaces receive the event name and +the namespace as first and second arguments. + Connect, Connect Error and Disconnect Event Handlers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/server.rst b/docs/server.rst index c91a98e..b53647a 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -178,21 +178,21 @@ The ``sid`` argument is the Socket.IO session id, a unique identifier of each client connection. All the events sent by a given client will have the same ``sid`` value. -Catch-All Event Handlers ------------------------- +Catch-All Event and Namespace Handlers +-------------------------------------- A "catch-all" event handler is invoked for any events that do not have an event handler. You can define a catch-all handler using ``'*'`` as event name:: @sio.on('*') - def catch_all(event, sid, data): - pass + def any_event(event, sid, data): + pass Asyncio servers can also use a coroutine:: @sio.on('*') - async def catch_all(event, sid, data): - pass + async def any_event(event, sid, data): + pass A catch-all event handler receives the event name as a first argument. The remaining arguments are the same as for a regular event handler. @@ -200,6 +200,27 @@ remaining arguments are the same as for a regular event handler. The ``connect`` and ``disconnect`` events have to be defined explicitly and are not invoked on a catch-all event handler. +Similarily, a "catch-all" namespace handler is invoked for any connected +namespaces that do not have an explicitly defined event handler. As with +catch-all events, ``'*'`` is used in place of a namespace:: + + @sio.on('my_event', namespace='*') + def my_event_any_namespace(namespace, sid, data): + pass + +For these events, the namespace is passed as first argument, followed by the +regular arguments of the event. + +Lastly, it is also possible to define a "catch-all" handler for all events on +all namespaces:: + + @sio.on('*', namespace='*') + def any_event_any_namespace(event, namespace, sid, data): + pass + +Event handlers with catch-all events and namespaces receive the event name and +the namespace as first and second arguments. + Connect and Disconnect Event Handlers ------------------------------------- diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index 88a6c4c..394afa1 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -429,28 +429,21 @@ class AsyncClient(base_client.BaseClient): async def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" # first see if we have an explicit handler for the event - if namespace in self.handlers: - handler = None - if event in self.handlers[namespace]: - handler = self.handlers[namespace][event] - elif event not in self.reserved_events and \ - '*' in self.handlers[namespace]: - handler = self.handlers[namespace]['*'] - args = (event, *args) - if handler: - if asyncio.iscoroutinefunction(handler): - try: - ret = await handler(*args) - except asyncio.CancelledError: # pragma: no cover - ret = None - else: - ret = handler(*args) - return ret + handler, args = self._get_event_handler(event, namespace, args) + if handler: + if asyncio.iscoroutinefunction(handler): + try: + ret = await handler(*args) + except asyncio.CancelledError: # pragma: no cover + ret = None + else: + ret = handler(*args) + return ret # or else, forward the event to a namepsace handler if one exists - elif namespace in self.namespace_handlers: - return await self.namespace_handlers[namespace].trigger_event( - event, *args) + handler, args = self._get_namespace_handler(namespace, args) + if handler: + return await handler.trigger_event(event, *args) async def _handle_reconnect(self): if self._reconnect_abort is None: # pragma: no cover diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 99af067..131a9c0 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -617,30 +617,22 @@ class AsyncServer(base_server.BaseServer): async def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" # first see if we have an explicit handler for the event - if namespace in self.handlers: - handler = None - if event in self.handlers[namespace]: - handler = self.handlers[namespace][event] - elif event not in self.reserved_events and \ - '*' in self.handlers[namespace]: - handler = self.handlers[namespace]['*'] - args = (event, *args) - if handler: - if asyncio.iscoroutinefunction(handler): - try: - ret = await handler(*args) - except asyncio.CancelledError: # pragma: no cover - ret = None - else: - ret = handler(*args) - return ret + handler, args = self._get_event_handler(event, namespace, args) + if handler: + if asyncio.iscoroutinefunction(handler): + try: + ret = await handler(*args) + except asyncio.CancelledError: # pragma: no cover + ret = None else: - return self.not_handled - - # or else, forward the event to a namepsace handler if one exists - elif namespace in self.namespace_handlers: # pragma: no branch - return await self.namespace_handlers[namespace].trigger_event( - event, *args) + ret = handler(*args) + return ret + # or else, forward the event to a namespace handler if one exists + handler, args = self._get_namespace_handler(namespace, args) + if handler: + return await handler.trigger_event(event, *args) + else: + return self.not_handled async def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 95fea1e..cd007cc 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -219,6 +219,46 @@ class BaseClient: """ return self.eio.transport() + def _get_event_handler(self, event, namespace, args): + # return the appropriate application event handler + # + # Resolution priority: + # - self.handlers[namespace][event] + # - self.handlers[namespace]["*"] + # - self.handlers["*"][event] + # - self.handlers["*"]["*"] + handler = None + if namespace in self.handlers: + if event in self.handlers[namespace]: + handler = self.handlers[namespace][event] + elif event not in self.reserved_events and \ + '*' in self.handlers[namespace]: + handler = self.handlers[namespace]['*'] + args = (event, *args) + elif '*' in self.handlers: + if event in self.handlers['*']: + handler = self.handlers['*'][event] + args = (namespace, *args) + elif event not in self.reserved_events and \ + '*' in self.handlers['*']: + handler = self.handlers['*']['*'] + args = (event, namespace, *args) + return handler, args + + def _get_namespace_handler(self, namespace, args): + # Return the appropriate application event handler. + # + # Resolution priority: + # - self.namespace_handlers[namespace] + # - self.namespace_handlers["*"] + handler = None + if namespace in self.namespace_handlers: + handler = self.namespace_handlers[namespace] + elif '*' in self.namespace_handlers: + handler = self.namespace_handlers['*'] + args = (namespace, *args) + return handler, args + def _generate_ack_id(self, namespace, callback): """Generate a unique identifier for an ACK packet.""" namespace = namespace or '/' diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index f8c9000..213158b 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -196,6 +196,48 @@ class BaseServer: eio_sid = self.manager.eio_sid_from_sid(sid, namespace or '/') return self.environ.get(eio_sid) + def _get_event_handler(self, event, namespace, args): + # Return the appropriate application event handler + # + # Resolution priority: + # - self.handlers[namespace][event] + # - self.handlers[namespace]["*"] + # - self.handlers["*"][event] + # - self.handlers["*"]["*"] + handler = None + print(event, namespace) + print(namespace in self.handlers) + if namespace in self.handlers: + if event in self.handlers[namespace]: + handler = self.handlers[namespace][event] + elif event not in self.reserved_events and \ + '*' in self.handlers[namespace]: + handler = self.handlers[namespace]['*'] + args = (event, *args) + elif '*' in self.handlers: + if event in self.handlers['*']: + handler = self.handlers['*'][event] + args = (namespace, *args) + elif event not in self.reserved_events and \ + '*' in self.handlers['*']: + handler = self.handlers['*']['*'] + args = (event, namespace, *args) + return handler, args + + def _get_namespace_handler(self, namespace, args): + # Return the appropriate application event handler. + # + # Resolution priority: + # - self.namespace_handlers[namespace] + # - self.namespace_handlers["*"] + handler = None + if namespace in self.namespace_handlers: + handler = self.namespace_handlers[namespace] + elif '*' in self.namespace_handlers: + handler = self.namespace_handlers['*'] + args = (namespace, *args) + return handler, args + def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/client.py b/src/socketio/client.py index 9f730a7..75d67dd 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -404,17 +404,14 @@ class Client(base_client.BaseClient): def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" # first see if we have an explicit handler for the event - if namespace in self.handlers: - if event in self.handlers[namespace]: - return self.handlers[namespace][event](*args) - elif event not in self.reserved_events and \ - '*' in self.handlers[namespace]: - return self.handlers[namespace]['*'](event, *args) + handler, args = self._get_event_handler(event, namespace, args) + if handler: + return handler(*args) # or else, forward the event to a namespace handler if one exists - elif namespace in self.namespace_handlers: - return self.namespace_handlers[namespace].trigger_event( - event, *args) + handler, args = self._get_namespace_handler(namespace, args) + if handler: + return handler.trigger_event(event, *args) def _handle_reconnect(self): if self._reconnect_abort is None: # pragma: no cover diff --git a/src/socketio/server.py b/src/socketio/server.py index 2081337..a331b76 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -604,19 +604,15 @@ class Server(base_server.BaseServer): def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" # first see if we have an explicit handler for the event - if namespace in self.handlers: - if event in self.handlers[namespace]: - return self.handlers[namespace][event](*args) - elif event not in self.reserved_events and \ - '*' in self.handlers[namespace]: - return self.handlers[namespace]['*'](event, *args) - else: - return self.not_handled - + handler, args = self._get_event_handler(event, namespace, args) + if handler: + return handler(*args) # or else, forward the event to a namespace handler if one exists - elif namespace in self.namespace_handlers: # pragma: no branch - return self.namespace_handlers[namespace].trigger_event( - event, *args) + handler, args = self._get_namespace_handler(namespace, args) + if handler: + return handler.trigger_event(event, *args) + else: + return self.not_handled def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 548b71c..d2a2b8c 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -857,6 +857,38 @@ class TestAsyncClient(unittest.TestCase): _run(c._trigger_event('foo', '/', 1, '2')) assert result == [1, '2'] + def test_trigger_event_with_catchall_class_namespace(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + def on_connect(self, ns): + result['result'] = (ns,) + + def on_disconnect(self, ns): + result['result'] = ('disconnect', ns) + + def on_foo(self, ns, data): + result['result'] = (ns, data) + + def on_bar(self, ns): + result['result'] = 'bar' + ns + + def on_baz(self, ns, data1, data2): + result['result'] = (ns, data1, data2) + + c = async_client.AsyncClient() + c.register_namespace(MyNamespace('*')) + _run(c._trigger_event('connect', '/foo')) + assert result['result'] == ('/foo',) + _run(c._trigger_event('foo', '/foo', 'a')) + assert result['result'] == ('/foo', 'a') + _run(c._trigger_event('bar', '/foo')) + assert result['result'] == 'bar/foo' + _run(c._trigger_event('baz', '/foo', 'a', 'b')) + assert result['result'] == ('/foo', 'a', 'b') + _run(c._trigger_event('disconnect', '/foo')) + assert result['result'] == ('disconnect', '/foo') + def test_trigger_event_unknown_namespace(self): c = async_client.AsyncClient() result = [] diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 2f84b5f..bc83bdc 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -621,6 +621,35 @@ class TestAsyncServer(unittest.TestCase): catchall_handler.assert_called_once_with( 'my message', sid, 'a', 'b', 'c') + def test_handle_event_with_catchall_namespace(self, eio): + eio.return_value.send = AsyncMock() + s = async_server.AsyncServer(async_handlers=False) + sid_foo = _run(s.manager.connect('123', '/foo')) + sid_bar = _run(s.manager.connect('123', '/bar')) + connect_star_handler = mock.MagicMock() + msg_foo_handler = mock.MagicMock() + msg_star_handler = mock.MagicMock() + star_foo_handler = mock.MagicMock() + star_star_handler = mock.MagicMock() + s.on('connect', connect_star_handler, namespace='*') + s.on('msg', msg_foo_handler, namespace='/foo') + s.on('msg', msg_star_handler, namespace='*') + s.on('*', star_foo_handler, namespace='/foo') + s.on('*', star_star_handler, namespace='*') + _run(s._trigger_event('connect', '/bar', sid_bar)) + _run(s._handle_eio_message('123', '2/foo,["msg","a","b"]')) + _run(s._handle_eio_message('123', '2/bar,["msg","a","b"]')) + _run(s._handle_eio_message('123', '2/foo,["my message","a","b","c"]')) + _run(s._handle_eio_message('123', '2/bar,["my message","a","b","c"]')) + _run(s._trigger_event('disconnect', '/bar', sid_bar)) + connect_star_handler.assert_called_once_with('/bar', sid_bar) + msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') + msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') + star_foo_handler.assert_called_once_with( + 'my message', sid_foo, 'a', 'b', 'c') + star_star_handler.assert_called_once_with( + 'my message', '/bar', sid_bar, 'a', 'b', 'c') + def test_handle_event_with_disconnected_namespace(self, eio): eio.return_value.send = AsyncMock() s = async_server.AsyncServer(async_handlers=False) @@ -904,6 +933,40 @@ class TestAsyncServer(unittest.TestCase): _run(s.disconnect('1', '/foo')) assert result['result'] == ('disconnect', '1') + def test_catchall_namespace_handler(self, eio): + eio.return_value.send = AsyncMock() + result = {} + + class MyNamespace(async_namespace.AsyncNamespace): + def on_connect(self, ns, sid, environ): + result['result'] = (sid, ns, environ) + + async def on_disconnect(self, ns, sid): + result['result'] = ('disconnect', sid, ns) + + async def on_foo(self, ns, sid, data): + result['result'] = (sid, ns, data) + + def on_bar(self, ns, sid): + result['result'] = 'bar' + ns + + async def on_baz(self, ns, sid, data1, data2): + result['result'] = (ns, data1, data2) + + s = async_server.AsyncServer(async_handlers=False, namespaces='*') + s.register_namespace(MyNamespace('*')) + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0/foo,')) + assert result['result'] == ('1', '/foo', 'environ') + _run(s._handle_eio_message('123', '2/foo,["foo","a"]')) + assert result['result'] == ('1', '/foo', 'a') + _run(s._handle_eio_message('123', '2/foo,["bar"]')) + assert result['result'] == 'bar/foo' + _run(s._handle_eio_message('123', '2/foo,["baz","a","b"]')) + assert result['result'] == ('/foo', 'a', 'b') + _run(s.disconnect('1', '/foo')) + assert result['result'] == ('disconnect', '1', '/foo') + def test_bad_namespace_handler(self, eio): class Dummy(object): pass diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 791cb43..637e7d5 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -970,6 +970,64 @@ class TestClient(unittest.TestCase): handler.assert_called_once_with(1, '2') catchall_handler.assert_called_once_with('bar', 1, '2', 3) + def test_trigger_event_with_catchall_namespace(self): + c = client.Client() + connect_star_handler = mock.MagicMock() + msg_foo_handler = mock.MagicMock() + msg_star_handler = mock.MagicMock() + star_foo_handler = mock.MagicMock() + star_star_handler = mock.MagicMock() + c.on('connect', connect_star_handler, namespace='*') + c.on('msg', msg_foo_handler, namespace='/foo') + c.on('msg', msg_star_handler, namespace='*') + c.on('*', star_foo_handler, namespace='/foo') + c.on('*', star_star_handler, namespace='*') + c._trigger_event('connect', '/bar') + c._trigger_event('msg', '/foo', 'a', 'b') + c._trigger_event('msg', '/bar', 'a', 'b') + c._trigger_event('my message', '/foo', 'a', 'b', 'c') + c._trigger_event('my message', '/bar', 'a', 'b', 'c') + c._trigger_event('disconnect', '/bar') + connect_star_handler.assert_called_once_with('/bar') + msg_foo_handler.assert_called_once_with('a', 'b') + msg_star_handler.assert_called_once_with('/bar', 'a', 'b') + star_foo_handler.assert_called_once_with( + 'my message', 'a', 'b', 'c') + star_star_handler.assert_called_once_with( + 'my message', '/bar', 'a', 'b', 'c') + + def test_trigger_event_with_catchall_namespace_handler(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_connect(self, ns): + result['result'] = (ns,) + + def on_disconnect(self, ns): + result['result'] = ('disconnect', ns) + + def on_foo(self, ns, data): + result['result'] = (ns, data) + + def on_bar(self, ns): + result['result'] = 'bar' + ns + + def on_baz(self, ns, data1, data2): + result['result'] = (ns, data1, data2) + + c = client.Client() + c.register_namespace(MyNamespace('*')) + c._trigger_event('connect', '/foo') + assert result['result'] == ('/foo',) + c._trigger_event('foo', '/foo', 'a') + assert result['result'] == ('/foo', 'a') + c._trigger_event('bar', '/foo') + assert result['result'] == 'bar/foo' + c._trigger_event('baz', '/foo', 'a', 'b') + assert result['result'] == ('/foo', 'a', 'b') + c._trigger_event('disconnect', '/foo') + assert result['result'] == ('disconnect', '/foo') + def test_trigger_event_class_namespace(self): c = client.Client() result = [] diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 08c59ac..8f3a356 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -574,6 +574,34 @@ class TestServer(unittest.TestCase): catchall_handler.assert_called_once_with( 'my message', '1', 'a', 'b', 'c') + def test_handle_event_with_catchall_namespace(self, eio): + s = server.Server(async_handlers=False) + sid_foo = s.manager.connect('123', '/foo') + sid_bar = s.manager.connect('123', '/bar') + connect_star_handler = mock.MagicMock() + msg_foo_handler = mock.MagicMock() + msg_star_handler = mock.MagicMock() + star_foo_handler = mock.MagicMock() + star_star_handler = mock.MagicMock() + s.on('connect', connect_star_handler, namespace='*') + s.on('msg', msg_foo_handler, namespace='/foo') + s.on('msg', msg_star_handler, namespace='*') + s.on('*', star_foo_handler, namespace='/foo') + s.on('*', star_star_handler, namespace='*') + s._trigger_event('connect', '/bar', sid_bar) + s._handle_eio_message('123', '2/foo,["msg","a","b"]') + s._handle_eio_message('123', '2/bar,["msg","a","b"]') + s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') + s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') + s._trigger_event('disconnect', '/bar', sid_bar) + connect_star_handler.assert_called_once_with('/bar', sid_bar) + msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') + msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') + star_foo_handler.assert_called_once_with( + 'my message', sid_foo, 'a', 'b', 'c') + star_star_handler.assert_called_once_with( + 'my message', '/bar', sid_bar, 'a', 'b', 'c') + def test_handle_event_with_disconnected_namespace(self, eio): s = server.Server(async_handlers=False) s.manager.connect('123', '/foo') @@ -815,6 +843,39 @@ class TestServer(unittest.TestCase): s.disconnect('1', '/foo') assert result['result'] == ('disconnect', '1') + def test_catchall_namespace_handler(self, eio): + result = {} + + class MyNamespace(namespace.Namespace): + def on_connect(self, ns, sid, environ): + result['result'] = (sid, ns, environ) + + def on_disconnect(self, ns, sid): + result['result'] = ('disconnect', sid, ns) + + def on_foo(self, ns, sid, data): + result['result'] = (sid, ns, data) + + def on_bar(self, ns, sid): + result['result'] = 'bar' + ns + + def on_baz(self, ns, sid, data1, data2): + result['result'] = (ns, data1, data2) + + s = server.Server(async_handlers=False, namespaces='*') + s.register_namespace(MyNamespace('*')) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0/foo,') + assert result['result'] == ('1', '/foo', 'environ') + s._handle_eio_message('123', '2/foo,["foo","a"]') + assert result['result'] == ('1', '/foo', 'a') + s._handle_eio_message('123', '2/foo,["bar"]') + assert result['result'] == 'bar/foo' + s._handle_eio_message('123', '2/foo,["baz","a","b"]') + assert result['result'] == ('/foo', 'a', 'b') + s.disconnect('1', '/foo') + assert result['result'] == ('disconnect', '1', '/foo') + def test_bad_namespace_handler(self, eio): class Dummy(object): pass