diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index d5a353b..551d343 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -1,4 +1,5 @@ import logging +from typing import Any, Callable, Dict, Optional, ParamSpec, TypeVar, Union, overload from . import manager from . import base_namespace @@ -6,13 +7,31 @@ from . import packet default_logger = logging.getLogger('socketio.server') +HandlerParams = ParamSpec("HandlerParams") +HandlerReturn = TypeVar("HandlerReturn") +EventHandler = Callable[HandlerParams, HandlerReturn] -class BaseServer: - reserved_events = ['connect', 'disconnect'] - def __init__(self, client_manager=None, logger=False, serializer='default', - json=None, async_handlers=True, always_connect=False, - namespaces=None, **kwargs): +class BaseServer: + handlers: Dict[str, Dict[str, Callable[..., Any]]] + namespace_handlers: Dict[ + str, Any + ] # Any is used here since base_namespace.BaseServerNamespace isn't imported + reserved_events: list[str] = ["connect", "disconnect"] + environ: Dict[str, Any] + _binary_packet: Dict[str, Any] + + def __init__( + self, + client_manager=None, + logger=False, + serializer="default", + json=None, + async_handlers=True, + always_connect=False, + namespaces=None, + **kwargs, + ): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -66,7 +85,31 @@ class BaseServer: def is_asyncio_based(self): return False - def on(self, event, handler=None, namespace=None): + @overload + def on( + self, event: str, handler: None = None, namespace: Optional[str] = None + ) -> Callable[[EventHandler], EventHandler]: ... + + @overload + def on( + self, + event: str, + handler: EventHandler, + namespace: Optional[str] = None, + ) -> None: ... + + def on( + self, + event: str, + handler: Optional[EventHandler] = None, + namespace: Optional[str] = None, + ) -> Union[ + Callable[ + [EventHandler], + EventHandler, + ], + None, + ]: """Register an event handler. :param event: The event name. It can be any string. The event names @@ -116,7 +159,7 @@ class BaseServer: """ namespace = namespace or '/' - def set_handler(handler): + def set_handler(handler: EventHandler) -> EventHandler: if namespace not in self.handlers: self.handlers[namespace] = {} self.handlers[namespace][event] = handler