From 28569d48ad74d5414a0d2a8f69d7540dbdddf066 Mon Sep 17 00:00:00 2001
From: Miguel Grinberg <miguel.grinberg@gmail.com>
Date: Fri, 3 Sep 2021 11:06:19 +0100
Subject: [PATCH] Catch-all event handlers

---
 docs/client.rst                      | 22 ++++++++++++++++++++++
 docs/server.rst                      | 22 ++++++++++++++++++++++
 src/socketio/asyncio_client.py       | 25 ++++++++++++++++---------
 src/socketio/asyncio_server.py       | 26 ++++++++++++++++----------
 src/socketio/client.py               |  7 +++++--
 src/socketio/server.py               |  7 +++++--
 tests/asyncio/test_asyncio_client.py |  8 ++++++++
 tests/asyncio/test_asyncio_server.py | 18 ++++++++++++++----
 tests/common/test_client.py          |  8 ++++++++
 tests/common/test_server.py          | 18 ++++++++++++++----
 10 files changed, 130 insertions(+), 31 deletions(-)

diff --git a/docs/client.rst b/docs/client.rst
index d777578..5bf85ad 100644
--- a/docs/client.rst
+++ b/docs/client.rst
@@ -65,6 +65,28 @@ or can also be coroutines::
     async def message(data):
         print('I received a message!')
 
+Catch-All Event Handlers
+------------------------
+
+A "catch-all" event handler is invoked for any events that do not have an
+event handler. You can define a catch-all handler using ``'*'`` as event name::
+
+   @sio.on('*')
+   def catch_all(event, sid, data):
+       pass
+
+Asyncio clients can also use a coroutine::
+
+   @sio.on('*')
+   async def catch_all(event, sid, data):
+      pass
+
+A catch-all event handler receives the event name as a first argument. The
+remaining arguments are the same as for a regular event handler.
+
+Connect, Connect Error and Disconnect Event Handlers
+----------------------------------------------------
+
 The ``connect``, ``connect_error`` and ``disconnect`` events are special; they 
 are invoked automatically when a client connects or disconnects from the
 server::
diff --git a/docs/server.rst b/docs/server.rst
index a2fc365..e553696 100644
--- a/docs/server.rst
+++ b/docs/server.rst
@@ -178,6 +178,28 @@ The ``sid`` argument is the Socket.IO session id, a unique identifier of each
 client connection. All the events sent by a given client will have the same
 ``sid`` value.
 
+Catch-All Event Handlers
+------------------------
+
+A "catch-all" event handler is invoked for any events that do not have an
+event handler. You can define a catch-all handler using ``'*'`` as event name::
+
+   @sio.on('*')
+   def catch_all(event, sid, data):
+       pass
+
+Asyncio servers can also use a coroutine::
+
+   @sio.on('*')
+   async def catch_all(event, sid, data):
+      pass
+
+A catch-all event handler receives the event name as a first argument. The
+remaining arguments are the same as for a regular event handler.
+
+Connect and Disconnect Event Handlers
+-------------------------------------
+
 The ``connect`` and ``disconnect`` events are special; they are invoked
 automatically when a client connects or disconnects from the server::
 
diff --git a/src/socketio/asyncio_client.py b/src/socketio/asyncio_client.py
index 63d0899..461e96f 100644
--- a/src/socketio/asyncio_client.py
+++ b/src/socketio/asyncio_client.py
@@ -418,15 +418,22 @@ class AsyncClient(client.Client):
     async def _trigger_event(self, event, namespace, *args):
         """Invoke an application event handler."""
         # first see if we have an explicit handler for the event
-        if namespace in self.handlers and event in self.handlers[namespace]:
-            if asyncio.iscoroutinefunction(self.handlers[namespace][event]):
-                try:
-                    ret = await self.handlers[namespace][event](*args)
-                except asyncio.CancelledError:  # pragma: no cover
-                    ret = None
-            else:
-                ret = self.handlers[namespace][event](*args)
-            return ret
+        if namespace in self.handlers:
+            handler = None
+            if event in self.handlers[namespace]:
+                handler = self.handlers[namespace][event]
+            elif '*' in self.handlers[namespace]:
+                handler = self.handlers[namespace]['*']
+                args = (event, *args)
+            if handler:
+                if asyncio.iscoroutinefunction(handler):
+                    try:
+                        ret = await handler(*args)
+                    except asyncio.CancelledError:  # pragma: no cover
+                        ret = None
+                else:
+                    ret = handler(*args)
+                return ret
 
         # or else, forward the event to a namepsace handler if one exists
         elif namespace in self.namespace_handlers:
diff --git a/src/socketio/asyncio_server.py b/src/socketio/asyncio_server.py
index ffdcdec..7e1b889 100644
--- a/src/socketio/asyncio_server.py
+++ b/src/socketio/asyncio_server.py
@@ -524,16 +524,22 @@ class AsyncServer(server.Server):
     async def _trigger_event(self, event, namespace, *args):
         """Invoke an application event handler."""
         # first see if we have an explicit handler for the event
-        if namespace in self.handlers and event in self.handlers[namespace]:
-            if asyncio.iscoroutinefunction(self.handlers[namespace][event]) \
-                    is True:
-                try:
-                    ret = await self.handlers[namespace][event](*args)
-                except asyncio.CancelledError:  # pragma: no cover
-                    ret = None
-            else:
-                ret = self.handlers[namespace][event](*args)
-            return ret
+        if namespace in self.handlers:
+            handler = None
+            if event in self.handlers[namespace]:
+                handler = self.handlers[namespace][event]
+            elif '*' in self.handlers[namespace]:
+                handler = self.handlers[namespace]['*']
+                args = (event, *args)
+            if handler:
+                if asyncio.iscoroutinefunction(handler):
+                    try:
+                        ret = await handler(*args)
+                    except asyncio.CancelledError:  # pragma: no cover
+                        ret = None
+                else:
+                    ret = handler(*args)
+                return ret
 
         # or else, forward the event to a namepsace handler if one exists
         elif namespace in self.namespace_handlers:
diff --git a/src/socketio/client.py b/src/socketio/client.py
index b631608..b30fea7 100644
--- a/src/socketio/client.py
+++ b/src/socketio/client.py
@@ -609,8 +609,11 @@ class Client(object):
     def _trigger_event(self, event, namespace, *args):
         """Invoke an application event handler."""
         # first see if we have an explicit handler for the event
-        if namespace in self.handlers and event in self.handlers[namespace]:
-            return self.handlers[namespace][event](*args)
+        if namespace in self.handlers:
+            if event in self.handlers[namespace]:
+                return self.handlers[namespace][event](*args)
+            elif '*' in self.handlers[namespace]:
+                return self.handlers[namespace]['*'](event, *args)
 
         # or else, forward the event to a namespace handler if one exists
         elif namespace in self.namespace_handlers:
diff --git a/src/socketio/server.py b/src/socketio/server.py
index 95c7134..d4dd22f 100644
--- a/src/socketio/server.py
+++ b/src/socketio/server.py
@@ -732,8 +732,11 @@ class Server(object):
     def _trigger_event(self, event, namespace, *args):
         """Invoke an application event handler."""
         # first see if we have an explicit handler for the event
-        if namespace in self.handlers and event in self.handlers[namespace]:
-            return self.handlers[namespace][event](*args)
+        if namespace in self.handlers:
+            if event in self.handlers[namespace]:
+                return self.handlers[namespace][event](*args)
+            elif '*' in self.handlers[namespace]:
+                return self.handlers[namespace]['*'](event, *args)
 
         # or else, forward the event to a namespace handler if one exists
         elif namespace in self.namespace_handlers:
diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py
index 38abcdd..fe3414e 100644
--- a/tests/asyncio/test_asyncio_client.py
+++ b/tests/asyncio/test_asyncio_client.py
@@ -833,16 +833,24 @@ class TestAsyncClient(unittest.TestCase):
     def test_trigger_event(self):
         c = asyncio_client.AsyncClient()
         handler = mock.MagicMock()
+        catchall_handler = mock.MagicMock()
         c.on('foo', handler)
+        c.on('*', catchall_handler)
         _run(c._trigger_event('foo', '/', 1, '2'))
+        _run(c._trigger_event('bar', '/', 1, '2', 3))
         handler.assert_called_once_with(1, '2')
+        catchall_handler.assert_called_once_with('bar', 1, '2', 3)
 
     def test_trigger_event_namespace(self):
         c = asyncio_client.AsyncClient()
         handler = AsyncMock()
+        catchall_handler = AsyncMock()
         c.on('foo', handler, namespace='/bar')
+        c.on('*', catchall_handler, namespace='/bar')
         _run(c._trigger_event('foo', '/bar', 1, '2'))
+        _run(c._trigger_event('bar', '/bar', 1, '2', 3))
         handler.mock.assert_called_once_with(1, '2')
+        catchall_handler.mock.assert_called_once_with('bar', 1, '2', 3)
 
     def test_trigger_event_class_namespace(self):
         c = asyncio_client.AsyncClient()
diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py
index a6c2667..824d4a5 100644
--- a/tests/asyncio/test_asyncio_server.py
+++ b/tests/asyncio/test_asyncio_server.py
@@ -618,18 +618,28 @@ class TestAsyncServer(unittest.TestCase):
         s = asyncio_server.AsyncServer(async_handlers=False)
         sid = s.manager.connect('123', '/')
         handler = AsyncMock()
-        s.on('my message', handler)
+        catchall_handler = AsyncMock()
+        s.on('msg', handler)
+        s.on('*', catchall_handler)
+        _run(s._handle_eio_message('123', '2["msg","a","b"]'))
         _run(s._handle_eio_message('123', '2["my message","a","b","c"]'))
-        handler.mock.assert_called_once_with(sid, 'a', 'b', 'c')
+        handler.mock.assert_called_once_with(sid, 'a', 'b')
+        catchall_handler.mock.assert_called_once_with(
+            'my message', sid, 'a', 'b', 'c')
 
     def test_handle_event_with_namespace(self, eio):
         eio.return_value.send = AsyncMock()
         s = asyncio_server.AsyncServer(async_handlers=False)
         sid = s.manager.connect('123', '/foo')
         handler = mock.MagicMock()
-        s.on('my message', handler, namespace='/foo')
+        catchall_handler = mock.MagicMock()
+        s.on('msg', handler, namespace='/foo')
+        s.on('*', catchall_handler, namespace='/foo')
+        _run(s._handle_eio_message('123', '2/foo,["msg","a","b"]'))
         _run(s._handle_eio_message('123', '2/foo,["my message","a","b","c"]'))
-        handler.assert_called_once_with(sid, 'a', 'b', 'c')
+        handler.assert_called_once_with(sid, 'a', 'b')
+        catchall_handler.assert_called_once_with(
+            'my message', sid, 'a', 'b', 'c')
 
     def test_handle_event_with_disconnected_namespace(self, eio):
         eio.return_value.send = AsyncMock()
diff --git a/tests/common/test_client.py b/tests/common/test_client.py
index 391187f..e87e86e 100644
--- a/tests/common/test_client.py
+++ b/tests/common/test_client.py
@@ -934,16 +934,24 @@ class TestClient(unittest.TestCase):
     def test_trigger_event(self):
         c = client.Client()
         handler = mock.MagicMock()
+        catchall_handler = mock.MagicMock()
         c.on('foo', handler)
+        c.on('*', catchall_handler)
         c._trigger_event('foo', '/', 1, '2')
+        c._trigger_event('bar', '/', 1, '2', 3)
         handler.assert_called_once_with(1, '2')
+        catchall_handler.assert_called_once_with('bar', 1, '2', 3)
 
     def test_trigger_event_namespace(self):
         c = client.Client()
         handler = mock.MagicMock()
+        catchall_handler = mock.MagicMock()
         c.on('foo', handler, namespace='/bar')
+        c.on('*', catchall_handler, namespace='/bar')
         c._trigger_event('foo', '/bar', 1, '2')
+        c._trigger_event('bar', '/bar', 1, '2', 3)
         handler.assert_called_once_with(1, '2')
+        catchall_handler.assert_called_once_with('bar', 1, '2', 3)
 
     def test_trigger_event_class_namespace(self):
         c = client.Client()
diff --git a/tests/common/test_server.py b/tests/common/test_server.py
index 1bdac11..3b89c3b 100644
--- a/tests/common/test_server.py
+++ b/tests/common/test_server.py
@@ -546,17 +546,27 @@ class TestServer(unittest.TestCase):
         s = server.Server(async_handlers=False)
         s.manager.connect('123', '/')
         handler = mock.MagicMock()
-        s.on('my message', handler)
+        catchall_handler = mock.MagicMock()
+        s.on('msg', handler)
+        s.on('*', catchall_handler)
+        s._handle_eio_message('123', '2["msg","a","b"]')
         s._handle_eio_message('123', '2["my message","a","b","c"]')
-        handler.assert_called_once_with('1', 'a', 'b', 'c')
+        handler.assert_called_once_with('1', 'a', 'b')
+        catchall_handler.assert_called_once_with(
+            'my message', '1', 'a', 'b', 'c')
 
     def test_handle_event_with_namespace(self, eio):
         s = server.Server(async_handlers=False)
         s.manager.connect('123', '/foo')
         handler = mock.MagicMock()
-        s.on('my message', handler, namespace='/foo')
+        catchall_handler = mock.MagicMock()
+        s.on('msg', handler, namespace='/foo')
+        s.on('*', catchall_handler, namespace='/foo')
+        s._handle_eio_message('123', '2/foo,["msg","a","b"]')
         s._handle_eio_message('123', '2/foo,["my message","a","b","c"]')
-        handler.assert_called_once_with('1', 'a', 'b', 'c')
+        handler.assert_called_once_with('1', 'a', 'b')
+        catchall_handler.assert_called_once_with(
+            'my message', '1', 'a', 'b', 'c')
 
     def test_handle_event_with_disconnected_namespace(self, eio):
         s = server.Server(async_handlers=False)