From 93e47ab590a0fff0b0c7b41d32807a9af3c6a525 Mon Sep 17 00:00:00 2001
From: Mathijs de Bruin <mathijs@mathijsfietst.nl>
Date: Mon, 25 Nov 2024 09:44:20 +0000
Subject: [PATCH] Type hints for event handler.

Prevents type errors in library use.
---
 src/socketio/base_server.py | 57 ++++++++++++++++++++++++++++++++-----
 1 file changed, 50 insertions(+), 7 deletions(-)

diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py
index d5a353b..551d343 100644
--- a/src/socketio/base_server.py
+++ b/src/socketio/base_server.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any, Callable, Dict, Optional, ParamSpec, TypeVar, Union, overload
 
 from . import manager
 from . import base_namespace
@@ -6,13 +7,31 @@ from . import packet
 
 default_logger = logging.getLogger('socketio.server')
 
+HandlerParams = ParamSpec("HandlerParams")
+HandlerReturn = TypeVar("HandlerReturn")
+EventHandler = Callable[HandlerParams, HandlerReturn]
 
-class BaseServer:
-    reserved_events = ['connect', 'disconnect']
 
-    def __init__(self, client_manager=None, logger=False, serializer='default',
-                 json=None, async_handlers=True, always_connect=False,
-                 namespaces=None, **kwargs):
+class BaseServer:
+    handlers: Dict[str, Dict[str, Callable[..., Any]]]
+    namespace_handlers: Dict[
+        str, Any
+    ]  # Any is used here since base_namespace.BaseServerNamespace isn't imported
+    reserved_events: list[str] = ["connect", "disconnect"]
+    environ: Dict[str, Any]
+    _binary_packet: Dict[str, Any]
+
+    def __init__(
+        self,
+        client_manager=None,
+        logger=False,
+        serializer="default",
+        json=None,
+        async_handlers=True,
+        always_connect=False,
+        namespaces=None,
+        **kwargs,
+    ):
         engineio_options = kwargs
         engineio_logger = engineio_options.pop('engineio_logger', None)
         if engineio_logger is not None:
@@ -66,7 +85,31 @@ class BaseServer:
     def is_asyncio_based(self):
         return False
 
-    def on(self, event, handler=None, namespace=None):
+    @overload
+    def on(
+        self, event: str, handler: None = None, namespace: Optional[str] = None
+    ) -> Callable[[EventHandler], EventHandler]: ...
+
+    @overload
+    def on(
+        self,
+        event: str,
+        handler: EventHandler,
+        namespace: Optional[str] = None,
+    ) -> None: ...
+
+    def on(
+        self,
+        event: str,
+        handler: Optional[EventHandler] = None,
+        namespace: Optional[str] = None,
+    ) -> Union[
+        Callable[
+            [EventHandler],
+            EventHandler,
+        ],
+        None,
+    ]:
         """Register an event handler.
 
         :param event: The event name. It can be any string. The event names
@@ -116,7 +159,7 @@ class BaseServer:
         """
         namespace = namespace or '/'
 
-        def set_handler(handler):
+        def set_handler(handler: EventHandler) -> EventHandler:
             if namespace not in self.handlers:
                 self.handlers[namespace] = {}
             self.handlers[namespace][event] = handler