From c66b28270ab470540b8b261cede93be5cb195f42 Mon Sep 17 00:00:00 2001 From: Konstantin Ponomarev Date: Wed, 13 Aug 2025 14:24:55 +0300 Subject: [PATCH] fix: type hints --- src/socketio/__init__.py | 2 + src/socketio/async_server.py | 167 +++++++++++++++++++-------------- src/socketio/base_namespace.py | 23 +++-- src/socketio/base_server.py | 94 +++++++++++-------- src/socketio/tornado.py | 13 ++- 5 files changed, 178 insertions(+), 121 deletions(-) diff --git a/src/socketio/__init__.py b/src/socketio/__init__.py index 5e74bf2..d2a4a89 100644 --- a/src/socketio/__init__.py +++ b/src/socketio/__init__.py @@ -17,6 +17,7 @@ from .redis_manager import RedisManager from .server import Server from .simple_client import SimpleClient from .tornado import get_tornado_handler +from .router import RouterSocketIO from .zmq_manager import ZmqManager __all__ = [ @@ -43,4 +44,5 @@ __all__ = [ "WSGIApp", "ZmqManager", "get_tornado_handler", + "RouterSocketIO", ] diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index c575abc..5355bfd 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -1,16 +1,30 @@ import asyncio +# pyright: reportMissingImports=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false +from typing import Any, AsyncContextManager, Callable, Dict, List, Optional, Set, Union, TYPE_CHECKING, Coroutine import engineio from . import async_manager, base_server, exceptions, packet +if TYPE_CHECKING: # pragma: no cover + from .async_admin import InstrumentedAsyncServer + # this set is used to keep references to background tasks to prevent them from # being garbage collected mid-execution. Solution taken from # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task -task_reference_holder = set() +task_reference_holder: Set[asyncio.Task[Any]] = set() class AsyncServer(base_server.BaseServer): + # Attribute type hints to aid static type checkers + manager: async_manager.AsyncManager + eio: Any + packet_class: Any + handlers: Dict[str, Dict[str, Any]] + namespace_handlers: Dict[str, Any] + namespaces: Union[List[str], str] + environ: Dict[str, Any] + _binary_packet: Dict[str, Any] """A Socket.IO server for asyncio. This class implements a fully compliant Socket.IO web server with support @@ -111,13 +125,13 @@ class AsyncServer(base_server.BaseServer): def __init__( self, - client_manager=None, - logger=False, - json=None, - async_handlers=True, - namespaces=None, - **kwargs, - ): + client_manager: Optional[async_manager.AsyncManager] = None, + logger: Union[bool, Any] = False, + json: Optional[Any] = None, + async_handlers: bool = True, + namespaces: Optional[Union[List[str], str]] = None, + **kwargs: Any, + ) -> None: if client_manager is None: client_manager = async_manager.AsyncManager() super().__init__( @@ -128,25 +142,26 @@ class AsyncServer(base_server.BaseServer): namespaces=namespaces, **kwargs, ) + # attributes are provided by the base class; runtime types are ensured - def is_asyncio_based(self): + def is_asyncio_based(self): # type: ignore[override] return True - def attach(self, app, socketio_path="socket.io"): + def attach(self, app: Any, socketio_path: str = "socket.io") -> None: """Attach the Socket.IO server to an application.""" self.eio.attach(app, socketio_path) async def emit( self, - event, - data=None, - to=None, - room=None, - skip_sid=None, - namespace=None, - callback=None, - ignore_queue=False, - ): + event: str, + data: Optional[Any] = None, + to: Optional[Union[str, List[str]]] = None, + room: Optional[Union[str, List[str]]] = None, + skip_sid: Optional[Union[str, List[str]]] = None, + namespace: Optional[str] = None, + callback: Optional[Callable[..., Any]] = None, + ignore_queue: bool = False, + ) -> None: """Emit a custom event to one or more connected clients. :param event: The event name. It can be any string. The event names @@ -207,14 +222,14 @@ class AsyncServer(base_server.BaseServer): async def send( self, - data, - to=None, - room=None, - skip_sid=None, - namespace=None, - callback=None, - ignore_queue=False, - ): + data: Any, + to: Optional[Union[str, List[str]]] = None, + room: Optional[Union[str, List[str]]] = None, + skip_sid: Optional[Union[str, List[str]]] = None, + namespace: Optional[str] = None, + callback: Optional[Callable[..., Any]] = None, + ignore_queue: bool = False, + ) -> None: """Send a message to one or more connected clients. This function emits an event with the name ``'message'``. Use @@ -265,14 +280,14 @@ class AsyncServer(base_server.BaseServer): async def call( self, - event, - data=None, - to=None, - sid=None, - namespace=None, - timeout=60, - ignore_queue=False, - ): + event: str, + data: Optional[Any] = None, + to: Optional[str] = None, + sid: Optional[str] = None, + namespace: Optional[str] = None, + timeout: float = 60, + ignore_queue: bool = False, + ) -> Any: """Emit a custom event to a client and wait for the response. This method issues an emit with a callback and waits for the callback @@ -319,7 +334,7 @@ class AsyncServer(base_server.BaseServer): callback_event = self.eio.create_event() callback_args = [] - def event_callback(*args): + def event_callback(*args: Any) -> None: callback_args.append(args) callback_event.set() @@ -343,7 +358,7 @@ class AsyncServer(base_server.BaseServer): else None ) - async def enter_room(self, sid, room, namespace=None): + async def enter_room(self, sid: str, room: str, namespace: Optional[str] = None) -> None: """Enter a room. This function adds the client to a room. The :func:`emit` and @@ -361,7 +376,7 @@ class AsyncServer(base_server.BaseServer): self.logger.info("%s is entering room %s [%s]", sid, room, namespace) await self.manager.enter_room(sid, namespace, room) - async def leave_room(self, sid, room, namespace=None): + async def leave_room(self, sid: str, room: str, namespace: Optional[str] = None) -> None: """Leave a room. This function removes the client from a room. @@ -377,7 +392,7 @@ class AsyncServer(base_server.BaseServer): self.logger.info("%s is leaving room %s [%s]", sid, room, namespace) await self.manager.leave_room(sid, namespace, room) - async def close_room(self, room, namespace=None): + async def close_room(self, room: str, namespace: Optional[str] = None) -> None: """Close a room. This function removes all the clients from the given room. @@ -392,7 +407,7 @@ class AsyncServer(base_server.BaseServer): self.logger.info("room %s is closing [%s]", room, namespace) await self.manager.close_room(room, namespace) - async def get_session(self, sid, namespace=None): + async def get_session(self, sid: str, namespace: Optional[str] = None) -> Dict[str, Any]: """Return the user session for a client. :param sid: The session id of the client. @@ -408,7 +423,7 @@ class AsyncServer(base_server.BaseServer): eio_session = await self.eio.get_session(eio_sid) return eio_session.setdefault(namespace, {}) - async def save_session(self, sid, session, namespace=None): + async def save_session(self, sid: str, session: Dict[str, Any], namespace: Optional[str] = None) -> None: """Store the user session for a client. :param sid: The session id of the client. @@ -421,7 +436,7 @@ class AsyncServer(base_server.BaseServer): eio_session = await self.eio.get_session(eio_sid) eio_session[namespace] = session - def session(self, sid, namespace=None): + def session(self, sid: str, namespace: Optional[str] = None) -> AsyncContextManager[Dict[str, Any]]: """Return the user session for a client with context manager syntax. :param sid: The session id of the client. @@ -446,26 +461,26 @@ class AsyncServer(base_server.BaseServer): """ class _session_context_manager: - def __init__(self, server, sid, namespace): + def __init__(self, server: "AsyncServer", sid: str, namespace: Optional[str]): self.server = server self.sid = sid self.namespace = namespace - self.session = None + self.session: Optional[Dict[str, Any]] = None - async def __aenter__(self): + async def __aenter__(self) -> Dict[str, Any]: self.session = await self.server.get_session( sid, namespace=self.namespace ) return self.session - async def __aexit__(self, *args): + async def __aexit__(self, *args: Any) -> None: await self.server.save_session( - sid, self.session, namespace=self.namespace + sid, self.session or {}, namespace=self.namespace ) return _session_context_manager(self, sid, namespace) - async def disconnect(self, sid, namespace=None, ignore_queue=False): + async def disconnect(self, sid: str, namespace: Optional[str] = None, ignore_queue: bool = False) -> None: """Disconnect a client. :param sid: Session ID of the client. @@ -495,7 +510,7 @@ class AsyncServer(base_server.BaseServer): ) await self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) - async def shutdown(self): + async def shutdown(self) -> None: """Stop Socket.IO background tasks. This method stops all background activity initiated by the Socket.IO @@ -504,7 +519,7 @@ class AsyncServer(base_server.BaseServer): self.logger.info("Socket.IO is shutting down") await self.eio.shutdown() - async def handle_request(self, *args, **kwargs): + async def handle_request(self, *args: Any, **kwargs: Any) -> Any: """Handle an HTTP request from the client. This is the entry point of the Socket.IO application. This function @@ -514,7 +529,7 @@ class AsyncServer(base_server.BaseServer): """ return await self.eio.handle_request(*args, **kwargs) - def start_background_task(self, target, *args, **kwargs): + def start_background_task(self, target: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any) -> asyncio.Task[Any]: """Start a background task using the appropriate async model. This is a utility function that applications can use to start a @@ -529,7 +544,7 @@ class AsyncServer(base_server.BaseServer): """ return self.eio.start_background_task(target, *args, **kwargs) - async def sleep(self, seconds=0): + async def sleep(self, seconds: float = 0) -> None: """Sleep for the requested amount of time using the appropriate async model. @@ -543,13 +558,13 @@ class AsyncServer(base_server.BaseServer): def instrument( self, - auth=None, - mode="development", - read_only=False, - server_id=None, - namespace="/admin", - server_stats_interval=2, - ): + auth: Optional[Union[bool, Dict[str, Any], List[Dict[str, Any]], Callable[[Dict[str, Any]], bool]]] = None, + mode: str = "development", + read_only: bool = False, + server_id: Optional[str] = None, + namespace: str = "/admin", + server_stats_interval: int = 2, + ) -> "InstrumentedAsyncServer": """Instrument the Socket.IO server for monitoring with the `Socket.IO Admin UI `_. @@ -592,7 +607,7 @@ class AsyncServer(base_server.BaseServer): server_stats_interval=server_stats_interval, ) - async def _send_packet(self, eio_sid, pkt): + async def _send_packet(self, eio_sid: str, pkt: packet.Packet) -> None: """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): @@ -601,11 +616,11 @@ class AsyncServer(base_server.BaseServer): else: await self.eio.send(eio_sid, encoded_packet) - async def _send_eio_packet(self, eio_sid, eio_pkt): + async def _send_eio_packet(self, eio_sid: str, eio_pkt: Any) -> None: """Send a raw Engine.IO packet to a client.""" await self.eio.send_packet(eio_sid, eio_pkt) - async def _handle_connect(self, eio_sid, namespace, data): + async def _handle_connect(self, eio_sid: str, namespace: Optional[str], data: Optional[Any]) -> None: """Handle a client connection request.""" namespace = namespace or "/" sid = None @@ -672,7 +687,7 @@ class AsyncServer(base_server.BaseServer): self.packet_class(packet.CONNECT, {"sid": sid}, namespace=namespace), ) - async def _handle_disconnect(self, eio_sid, namespace, reason=None): + async def _handle_disconnect(self, eio_sid: str, namespace: Optional[str], reason: Optional[str] = None) -> None: """Handle a client disconnect.""" namespace = namespace or "/" sid = self.manager.sid_from_eio_sid(eio_sid, namespace) @@ -684,10 +699,12 @@ class AsyncServer(base_server.BaseServer): ) await self.manager.disconnect(sid, namespace, ignore_queue=True) - async def _handle_event(self, eio_sid, namespace, id, data): + async def _handle_event(self, eio_sid: str, namespace: Optional[str], id: Optional[int], data: Any) -> None: """Handle an incoming client event.""" namespace = namespace or "/" sid = self.manager.sid_from_eio_sid(eio_sid, namespace) + if sid is None: + return self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) if not self.manager.is_connected(sid, namespace): self.logger.warning("%s is not connected to namespace %s", sid, namespace) @@ -701,7 +718,15 @@ class AsyncServer(base_server.BaseServer): else: await self._handle_event_internal(self, sid, eio_sid, data, namespace, id) - async def _handle_event_internal(self, server, sid, eio_sid, data, namespace, id): + async def _handle_event_internal( + self, + server: "AsyncServer", + sid: str, + eio_sid: str, + data: Any, + namespace: Optional[str], + id: Optional[int], + ) -> None: r = await server._trigger_event(data[0], namespace, sid, *data[1:]) if r != self.not_handled and id is not None: # send ACK packet with the response returned by the handler @@ -717,14 +742,14 @@ class AsyncServer(base_server.BaseServer): self.packet_class(packet.ACK, namespace=namespace, id=id, data=data), ) - async def _handle_ack(self, eio_sid, namespace, id, data): + async def _handle_ack(self, eio_sid: str, namespace: Optional[str], id: Optional[int], data: Any) -> None: """Handle ACK packets from the client.""" namespace = namespace or "/" sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info("received ack from %s [%s]", sid, namespace) await self.manager.trigger_callback(sid, id, data) - async def _trigger_event(self, event, namespace, *args): + async def _trigger_event(self, event: str, namespace: Optional[str], *args: Any) -> Any: """Invoke an application event handler.""" # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) @@ -757,14 +782,14 @@ class AsyncServer(base_server.BaseServer): return await handler.trigger_event(event, *args) return self.not_handled - async def _handle_eio_connect(self, eio_sid, environ): + async def _handle_eio_connect(self, eio_sid: str, environ: Dict[str, Any]) -> None: # type: ignore[override] """Handle the Engine.IO connection event.""" if not self.manager_initialized: self.manager_initialized = True self.manager.initialize() self.environ[eio_sid] = environ - async def _handle_eio_message(self, eio_sid, data): + async def _handle_eio_message(self, eio_sid: str, data: Union[bytes, str]) -> None: # type: ignore[override] """Dispatch Engine.IO messages.""" if eio_sid in self._binary_packet: pkt = self._binary_packet[eio_sid] @@ -796,12 +821,12 @@ class AsyncServer(base_server.BaseServer): else: raise ValueError("Unknown packet type.") - async def _handle_eio_disconnect(self, eio_sid, reason): + async def _handle_eio_disconnect(self, eio_sid: str, reason: str) -> None: # type: ignore[override] """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): await self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] - def _engineio_server_class(self) -> engineio.AsyncServer: + def _engineio_server_class(self) -> Any: return engineio.AsyncServer diff --git a/src/socketio/base_namespace.py b/src/socketio/base_namespace.py index bf0e13d..2c17147 100644 --- a/src/socketio/base_namespace.py +++ b/src/socketio/base_namespace.py @@ -1,20 +1,23 @@ +from typing import Any, Optional + + class BaseNamespace: - def __init__(self, namespace=None): - self.namespace = namespace or "/" + def __init__(self, namespace: Optional[str] = None): + self.namespace: str = namespace or "/" - def is_asyncio_based(self): + def is_asyncio_based(self) -> bool: return False class BaseServerNamespace(BaseNamespace): - def __init__(self, namespace=None): + def __init__(self, namespace: Optional[str] = None): super().__init__(namespace=namespace) - self.server = None + self.server: Any = None - def _set_server(self, server): + def _set_server(self, server: Any) -> None: self.server = server - def rooms(self, sid, namespace=None): + def rooms(self, sid: str, namespace: Optional[str] = None): """Return the rooms a client is in. The only difference with the :func:`socketio.Server.rooms` method is @@ -25,9 +28,9 @@ class BaseServerNamespace(BaseNamespace): class BaseClientNamespace(BaseNamespace): - def __init__(self, namespace=None): + def __init__(self, namespace: Optional[str] = None): super().__init__(namespace=namespace) - self.client = None + self.client: Any = None - def _set_client(self, client): + def _set_client(self, client: Any) -> None: self.client = client diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index 45083fe..d1ca34e 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -1,4 +1,6 @@ import logging +# pyright: reportMissingImports=false +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import engineio @@ -8,20 +10,20 @@ default_logger = logging.getLogger("socketio.server") class BaseServer: - reserved_events = ["connect", "disconnect"] - reason = engineio.Server.reason + reserved_events: List[str] = ["connect", "disconnect"] + reason = engineio.Server.reason # type: ignore[attr-defined] def __init__( self, - client_manager=None, - logger=False, - serializer="default", - json=None, - async_handlers=True, - always_connect=False, - namespaces=None, - **kwargs, - ): + client_manager: Optional[manager.Manager] = None, + logger: Union[bool, logging.Logger] = False, + serializer: Union[str, Any] = "default", + json: Optional[Any] = None, + async_handlers: bool = True, + always_connect: bool = False, + namespaces: Optional[Union[List[str], str]] = None, + **kwargs: Any, + ) -> None: engineio_options = kwargs engineio_logger = engineio_options.pop("engineio_logger", None) if engineio_logger is not None: @@ -35,20 +37,21 @@ class BaseServer: else: self.packet_class = serializer if json is not None: - self.packet_class.json = json + # packet_class is a class with a 'json' attribute at runtime + setattr(self.packet_class, "json", json) engineio_options["json"] = json engineio_options["async_handlers"] = False - self.eio = self._engineio_server_class()(**engineio_options) + self.eio: Any = self._engineio_server_class()(**engineio_options) self.eio.on("connect", self._handle_eio_connect) self.eio.on("message", self._handle_eio_message) self.eio.on("disconnect", self._handle_eio_disconnect) - self.environ = {} - self.handlers = {} - self.namespace_handlers = {} - self.not_handled = object() + self.environ: Dict[str, Dict[str, Any]] = {} + self.handlers: Dict[str, Dict[str, Callable[..., Any]]] = {} + self.namespace_handlers: Dict[str, base_namespace.BaseServerNamespace] = {} + self.not_handled: object = object() - self._binary_packet = {} + self._binary_packet: Dict[str, Any] = {} if not isinstance(logger, bool): self.logger = logger @@ -63,20 +66,25 @@ class BaseServer: if client_manager is None: client_manager = manager.Manager() - self.manager = client_manager + self.manager: Any = client_manager self.manager.set_server(self) self.manager_initialized = False self.async_handlers = async_handlers self.always_connect = always_connect - self.namespaces = namespaces or ["/"] + self.namespaces: Union[List[str], str] = namespaces or ["/"] self.async_mode = self.eio.async_mode - def is_asyncio_based(self): + def is_asyncio_based(self) -> bool: return False - def on(self, event, handler=None, namespace=None): + def on( + self, + event: str, + handler: Optional[Callable[..., Any]] = None, + namespace: Optional[str] = None, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Register an event handler. :param event: The event name. It can be any string. The event names @@ -129,7 +137,7 @@ class BaseServer: """ namespace = namespace or "/" - def set_handler(handler): + def set_handler(handler: Callable[..., Any]) -> Callable[..., Any]: if namespace not in self.handlers: self.handlers[namespace] = {} self.handlers[namespace][event] = handler @@ -138,8 +146,9 @@ class BaseServer: if handler is None: return set_handler set_handler(handler) + return set_handler - def event(self, *args, **kwargs): + def event(self, *args: Any, **kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorator to register an event handler. This is a simplified version of the ``on()`` method that takes the @@ -169,26 +178,31 @@ class BaseServer: return self.on(args[0].__name__)(args[0]) # the decorator was invoked with arguments - def set_handler(handler): + def set_handler(handler: Callable[..., Any]) -> Callable[..., Any]: return self.on(handler.__name__, *args, **kwargs)(handler) return set_handler - def register_namespace(self, namespace_handler): + def register_namespace(self, namespace_handler: base_namespace.BaseServerNamespace) -> None: """Register a namespace handler object. :param namespace_handler: An instance of a :class:`Namespace` subclass that handles all the event traffic for a namespace. """ - if not isinstance(namespace_handler, base_namespace.BaseServerNamespace): + if not isinstance(namespace_handler, base_namespace.BaseServerNamespace): # type: ignore[redundant-expr] raise ValueError("Not a namespace instance") if self.is_asyncio_based() != namespace_handler.is_asyncio_based(): raise ValueError("Not a valid namespace class for this server") namespace_handler._set_server(self) self.namespace_handlers[namespace_handler.namespace] = namespace_handler + namespace_handler._set_server(self) # type: ignore[misc] + ns: str = str(namespace_handler.namespace or "/") + self.namespace_handlers[ns] = namespace_handler + def rooms(self, sid, namespace=None): + def rooms(self, sid: str, namespace: Optional[str] = None) -> List[str]: """Return the rooms a client is in. :param sid: Session ID of the client. @@ -198,7 +212,7 @@ class BaseServer: namespace = namespace or "/" return self.manager.get_rooms(sid, namespace) - def transport(self, sid, namespace=None): + def transport(self, sid: str, namespace: Optional[str] = None) -> str: """Return the name of the transport used by the client. The two possible values returned by this function are ``'polling'`` @@ -211,7 +225,7 @@ class BaseServer: eio_sid = self.manager.eio_sid_from_sid(sid, namespace or "/") return self.eio.transport(eio_sid) - def get_environ(self, sid, namespace=None): + def get_environ(self, sid: str, namespace: Optional[str] = None) -> Optional[Dict[str, Any]]: """Return the WSGI environ dictionary for a client. :param sid: The session of the client. @@ -219,9 +233,13 @@ class BaseServer: the default namespace is used. """ eio_sid = self.manager.eio_sid_from_sid(sid, namespace or "/") + if eio_sid is None: + return None return self.environ.get(eio_sid) - def _get_event_handler(self, event, namespace, args): + def _get_event_handler( + self, event: str, namespace: Optional[str], args: Tuple[Any, ...] + ) -> Tuple[Optional[Callable[..., Any]], Tuple[Any, ...]]: # Return the appropriate application event handler # # Resolution priority: @@ -229,7 +247,7 @@ class BaseServer: # - self.handlers[namespace]["*"] # - self.handlers["*"][event] # - self.handlers["*"]["*"] - handler = None + handler: Optional[Callable[..., Any]] = None if namespace in self.handlers: if event in self.handlers[namespace]: handler = self.handlers[namespace][event] @@ -245,13 +263,15 @@ class BaseServer: args = (event, namespace, *args) return handler, args - def _get_namespace_handler(self, namespace, args): + def _get_namespace_handler( + self, namespace: Optional[str], args: Tuple[Any, ...] + ) -> Tuple[Optional[base_namespace.BaseServerNamespace], Tuple[Any, ...]]: # Return the appropriate application event handler. # # Resolution priority: # - self.namespace_handlers[namespace] # - self.namespace_handlers["*"] - handler = None + handler: Optional[base_namespace.BaseServerNamespace] = None if namespace in self.namespace_handlers: handler = self.namespace_handlers[namespace] if handler is None and "*" in self.namespace_handlers: @@ -259,14 +279,14 @@ class BaseServer: args = (namespace, *args) return handler, args - def _handle_eio_connect(self): # pragma: no cover + def _handle_eio_connect(self) -> None: # pragma: no cover raise NotImplementedError - def _handle_eio_message(self, data): # pragma: no cover + def _handle_eio_message(self, data: Any) -> None: # pragma: no cover raise NotImplementedError - def _handle_eio_disconnect(self): # pragma: no cover + def _handle_eio_disconnect(self) -> None: # pragma: no cover raise NotImplementedError - def _engineio_server_class(self): # pragma: no cover + def _engineio_server_class(self) -> Any: # pragma: no cover raise NotImplementedError("Must be implemented in subclasses") diff --git a/src/socketio/tornado.py b/src/socketio/tornado.py index daf243d..c8d6e35 100644 --- a/src/socketio/tornado.py +++ b/src/socketio/tornado.py @@ -1,10 +1,17 @@ +from typing import Any + +# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false + try: from engineio.async_drivers.tornado import ( get_tornado_handler as get_engineio_handler, ) except ImportError: # pragma: no cover - get_engineio_handler = None + get_engineio_handler = None # type: ignore[assignment] -def get_tornado_handler(socketio_server): # pragma: no cover - return get_engineio_handler(socketio_server.eio) +def get_tornado_handler(socketio_server: Any) -> Any: # pragma: no cover + # engineio handler factory expects an Engine.IO server instance + if get_engineio_handler is None: + raise RuntimeError("Tornado async driver is not available") + return get_engineio_handler(socketio_server.eio) # type: ignore[operator]