Browse Source

fix: type hints

pull/1486/head
Konstantin Ponomarev 1 week ago
parent
commit
c66b28270a
  1. 2
      src/socketio/__init__.py
  2. 167
      src/socketio/async_server.py
  3. 23
      src/socketio/base_namespace.py
  4. 94
      src/socketio/base_server.py
  5. 13
      src/socketio/tornado.py

2
src/socketio/__init__.py

@ -17,6 +17,7 @@ from .redis_manager import RedisManager
from .server import Server from .server import Server
from .simple_client import SimpleClient from .simple_client import SimpleClient
from .tornado import get_tornado_handler from .tornado import get_tornado_handler
from .router import RouterSocketIO
from .zmq_manager import ZmqManager from .zmq_manager import ZmqManager
__all__ = [ __all__ = [
@ -43,4 +44,5 @@ __all__ = [
"WSGIApp", "WSGIApp",
"ZmqManager", "ZmqManager",
"get_tornado_handler", "get_tornado_handler",
"RouterSocketIO",
] ]

167
src/socketio/async_server.py

@ -1,16 +1,30 @@
import asyncio import asyncio
# pyright: reportMissingImports=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false
from typing import Any, AsyncContextManager, Callable, Dict, List, Optional, Set, Union, TYPE_CHECKING, Coroutine
import engineio import engineio
from . import async_manager, base_server, exceptions, packet from . import async_manager, base_server, exceptions, packet
if TYPE_CHECKING: # pragma: no cover
from .async_admin import InstrumentedAsyncServer
# this set is used to keep references to background tasks to prevent them from # this set is used to keep references to background tasks to prevent them from
# being garbage collected mid-execution. Solution taken from # being garbage collected mid-execution. Solution taken from
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
task_reference_holder = set() task_reference_holder: Set[asyncio.Task[Any]] = set()
class AsyncServer(base_server.BaseServer): class AsyncServer(base_server.BaseServer):
# Attribute type hints to aid static type checkers
manager: async_manager.AsyncManager
eio: Any
packet_class: Any
handlers: Dict[str, Dict[str, Any]]
namespace_handlers: Dict[str, Any]
namespaces: Union[List[str], str]
environ: Dict[str, Any]
_binary_packet: Dict[str, Any]
"""A Socket.IO server for asyncio. """A Socket.IO server for asyncio.
This class implements a fully compliant Socket.IO web server with support This class implements a fully compliant Socket.IO web server with support
@ -111,13 +125,13 @@ class AsyncServer(base_server.BaseServer):
def __init__( def __init__(
self, self,
client_manager=None, client_manager: Optional[async_manager.AsyncManager] = None,
logger=False, logger: Union[bool, Any] = False,
json=None, json: Optional[Any] = None,
async_handlers=True, async_handlers: bool = True,
namespaces=None, namespaces: Optional[Union[List[str], str]] = None,
**kwargs, **kwargs: Any,
): ) -> None:
if client_manager is None: if client_manager is None:
client_manager = async_manager.AsyncManager() client_manager = async_manager.AsyncManager()
super().__init__( super().__init__(
@ -128,25 +142,26 @@ class AsyncServer(base_server.BaseServer):
namespaces=namespaces, namespaces=namespaces,
**kwargs, **kwargs,
) )
# attributes are provided by the base class; runtime types are ensured
def is_asyncio_based(self): def is_asyncio_based(self): # type: ignore[override]
return True return True
def attach(self, app, socketio_path="socket.io"): def attach(self, app: Any, socketio_path: str = "socket.io") -> None:
"""Attach the Socket.IO server to an application.""" """Attach the Socket.IO server to an application."""
self.eio.attach(app, socketio_path) self.eio.attach(app, socketio_path)
async def emit( async def emit(
self, self,
event, event: str,
data=None, data: Optional[Any] = None,
to=None, to: Optional[Union[str, List[str]]] = None,
room=None, room: Optional[Union[str, List[str]]] = None,
skip_sid=None, skip_sid: Optional[Union[str, List[str]]] = None,
namespace=None, namespace: Optional[str] = None,
callback=None, callback: Optional[Callable[..., Any]] = None,
ignore_queue=False, ignore_queue: bool = False,
): ) -> None:
"""Emit a custom event to one or more connected clients. """Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names :param event: The event name. It can be any string. The event names
@ -207,14 +222,14 @@ class AsyncServer(base_server.BaseServer):
async def send( async def send(
self, self,
data, data: Any,
to=None, to: Optional[Union[str, List[str]]] = None,
room=None, room: Optional[Union[str, List[str]]] = None,
skip_sid=None, skip_sid: Optional[Union[str, List[str]]] = None,
namespace=None, namespace: Optional[str] = None,
callback=None, callback: Optional[Callable[..., Any]] = None,
ignore_queue=False, ignore_queue: bool = False,
): ) -> None:
"""Send a message to one or more connected clients. """Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use This function emits an event with the name ``'message'``. Use
@ -265,14 +280,14 @@ class AsyncServer(base_server.BaseServer):
async def call( async def call(
self, self,
event, event: str,
data=None, data: Optional[Any] = None,
to=None, to: Optional[str] = None,
sid=None, sid: Optional[str] = None,
namespace=None, namespace: Optional[str] = None,
timeout=60, timeout: float = 60,
ignore_queue=False, ignore_queue: bool = False,
): ) -> Any:
"""Emit a custom event to a client and wait for the response. """Emit a custom event to a client and wait for the response.
This method issues an emit with a callback and waits for the callback This method issues an emit with a callback and waits for the callback
@ -319,7 +334,7 @@ class AsyncServer(base_server.BaseServer):
callback_event = self.eio.create_event() callback_event = self.eio.create_event()
callback_args = [] callback_args = []
def event_callback(*args): def event_callback(*args: Any) -> None:
callback_args.append(args) callback_args.append(args)
callback_event.set() callback_event.set()
@ -343,7 +358,7 @@ class AsyncServer(base_server.BaseServer):
else None else None
) )
async def enter_room(self, sid, room, namespace=None): async def enter_room(self, sid: str, room: str, namespace: Optional[str] = None) -> None:
"""Enter a room. """Enter a room.
This function adds the client to a room. The :func:`emit` and This function adds the client to a room. The :func:`emit` and
@ -361,7 +376,7 @@ class AsyncServer(base_server.BaseServer):
self.logger.info("%s is entering room %s [%s]", sid, room, namespace) self.logger.info("%s is entering room %s [%s]", sid, room, namespace)
await self.manager.enter_room(sid, namespace, room) await self.manager.enter_room(sid, namespace, room)
async def leave_room(self, sid, room, namespace=None): async def leave_room(self, sid: str, room: str, namespace: Optional[str] = None) -> None:
"""Leave a room. """Leave a room.
This function removes the client from a room. This function removes the client from a room.
@ -377,7 +392,7 @@ class AsyncServer(base_server.BaseServer):
self.logger.info("%s is leaving room %s [%s]", sid, room, namespace) self.logger.info("%s is leaving room %s [%s]", sid, room, namespace)
await self.manager.leave_room(sid, namespace, room) await self.manager.leave_room(sid, namespace, room)
async def close_room(self, room, namespace=None): async def close_room(self, room: str, namespace: Optional[str] = None) -> None:
"""Close a room. """Close a room.
This function removes all the clients from the given room. This function removes all the clients from the given room.
@ -392,7 +407,7 @@ class AsyncServer(base_server.BaseServer):
self.logger.info("room %s is closing [%s]", room, namespace) self.logger.info("room %s is closing [%s]", room, namespace)
await self.manager.close_room(room, namespace) await self.manager.close_room(room, namespace)
async def get_session(self, sid, namespace=None): async def get_session(self, sid: str, namespace: Optional[str] = None) -> Dict[str, Any]:
"""Return the user session for a client. """Return the user session for a client.
:param sid: The session id of the client. :param sid: The session id of the client.
@ -408,7 +423,7 @@ class AsyncServer(base_server.BaseServer):
eio_session = await self.eio.get_session(eio_sid) eio_session = await self.eio.get_session(eio_sid)
return eio_session.setdefault(namespace, {}) return eio_session.setdefault(namespace, {})
async def save_session(self, sid, session, namespace=None): async def save_session(self, sid: str, session: Dict[str, Any], namespace: Optional[str] = None) -> None:
"""Store the user session for a client. """Store the user session for a client.
:param sid: The session id of the client. :param sid: The session id of the client.
@ -421,7 +436,7 @@ class AsyncServer(base_server.BaseServer):
eio_session = await self.eio.get_session(eio_sid) eio_session = await self.eio.get_session(eio_sid)
eio_session[namespace] = session eio_session[namespace] = session
def session(self, sid, namespace=None): def session(self, sid: str, namespace: Optional[str] = None) -> AsyncContextManager[Dict[str, Any]]:
"""Return the user session for a client with context manager syntax. """Return the user session for a client with context manager syntax.
:param sid: The session id of the client. :param sid: The session id of the client.
@ -446,26 +461,26 @@ class AsyncServer(base_server.BaseServer):
""" """
class _session_context_manager: class _session_context_manager:
def __init__(self, server, sid, namespace): def __init__(self, server: "AsyncServer", sid: str, namespace: Optional[str]):
self.server = server self.server = server
self.sid = sid self.sid = sid
self.namespace = namespace self.namespace = namespace
self.session = None self.session: Optional[Dict[str, Any]] = None
async def __aenter__(self): async def __aenter__(self) -> Dict[str, Any]:
self.session = await self.server.get_session( self.session = await self.server.get_session(
sid, namespace=self.namespace sid, namespace=self.namespace
) )
return self.session return self.session
async def __aexit__(self, *args): async def __aexit__(self, *args: Any) -> None:
await self.server.save_session( await self.server.save_session(
sid, self.session, namespace=self.namespace sid, self.session or {}, namespace=self.namespace
) )
return _session_context_manager(self, sid, namespace) return _session_context_manager(self, sid, namespace)
async def disconnect(self, sid, namespace=None, ignore_queue=False): async def disconnect(self, sid: str, namespace: Optional[str] = None, ignore_queue: bool = False) -> None:
"""Disconnect a client. """Disconnect a client.
:param sid: Session ID of the client. :param sid: Session ID of the client.
@ -495,7 +510,7 @@ class AsyncServer(base_server.BaseServer):
) )
await self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) await self.manager.disconnect(sid, namespace=namespace, ignore_queue=True)
async def shutdown(self): async def shutdown(self) -> None:
"""Stop Socket.IO background tasks. """Stop Socket.IO background tasks.
This method stops all background activity initiated by the Socket.IO This method stops all background activity initiated by the Socket.IO
@ -504,7 +519,7 @@ class AsyncServer(base_server.BaseServer):
self.logger.info("Socket.IO is shutting down") self.logger.info("Socket.IO is shutting down")
await self.eio.shutdown() await self.eio.shutdown()
async def handle_request(self, *args, **kwargs): async def handle_request(self, *args: Any, **kwargs: Any) -> Any:
"""Handle an HTTP request from the client. """Handle an HTTP request from the client.
This is the entry point of the Socket.IO application. This function This is the entry point of the Socket.IO application. This function
@ -514,7 +529,7 @@ class AsyncServer(base_server.BaseServer):
""" """
return await self.eio.handle_request(*args, **kwargs) return await self.eio.handle_request(*args, **kwargs)
def start_background_task(self, target, *args, **kwargs): def start_background_task(self, target: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any) -> asyncio.Task[Any]:
"""Start a background task using the appropriate async model. """Start a background task using the appropriate async model.
This is a utility function that applications can use to start a This is a utility function that applications can use to start a
@ -529,7 +544,7 @@ class AsyncServer(base_server.BaseServer):
""" """
return self.eio.start_background_task(target, *args, **kwargs) return self.eio.start_background_task(target, *args, **kwargs)
async def sleep(self, seconds=0): async def sleep(self, seconds: float = 0) -> None:
"""Sleep for the requested amount of time using the appropriate async """Sleep for the requested amount of time using the appropriate async
model. model.
@ -543,13 +558,13 @@ class AsyncServer(base_server.BaseServer):
def instrument( def instrument(
self, self,
auth=None, auth: Optional[Union[bool, Dict[str, Any], List[Dict[str, Any]], Callable[[Dict[str, Any]], bool]]] = None,
mode="development", mode: str = "development",
read_only=False, read_only: bool = False,
server_id=None, server_id: Optional[str] = None,
namespace="/admin", namespace: str = "/admin",
server_stats_interval=2, server_stats_interval: int = 2,
): ) -> "InstrumentedAsyncServer":
"""Instrument the Socket.IO server for monitoring with the `Socket.IO """Instrument the Socket.IO server for monitoring with the `Socket.IO
Admin UI <https://socket.io/docs/v4/admin-ui/>`_. Admin UI <https://socket.io/docs/v4/admin-ui/>`_.
@ -592,7 +607,7 @@ class AsyncServer(base_server.BaseServer):
server_stats_interval=server_stats_interval, server_stats_interval=server_stats_interval,
) )
async def _send_packet(self, eio_sid, pkt): async def _send_packet(self, eio_sid: str, pkt: packet.Packet) -> None:
"""Send a Socket.IO packet to a client.""" """Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode() encoded_packet = pkt.encode()
if isinstance(encoded_packet, list): if isinstance(encoded_packet, list):
@ -601,11 +616,11 @@ class AsyncServer(base_server.BaseServer):
else: else:
await self.eio.send(eio_sid, encoded_packet) await self.eio.send(eio_sid, encoded_packet)
async def _send_eio_packet(self, eio_sid, eio_pkt): async def _send_eio_packet(self, eio_sid: str, eio_pkt: Any) -> None:
"""Send a raw Engine.IO packet to a client.""" """Send a raw Engine.IO packet to a client."""
await self.eio.send_packet(eio_sid, eio_pkt) await self.eio.send_packet(eio_sid, eio_pkt)
async def _handle_connect(self, eio_sid, namespace, data): async def _handle_connect(self, eio_sid: str, namespace: Optional[str], data: Optional[Any]) -> None:
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or "/" namespace = namespace or "/"
sid = None sid = None
@ -672,7 +687,7 @@ class AsyncServer(base_server.BaseServer):
self.packet_class(packet.CONNECT, {"sid": sid}, namespace=namespace), self.packet_class(packet.CONNECT, {"sid": sid}, namespace=namespace),
) )
async def _handle_disconnect(self, eio_sid, namespace, reason=None): async def _handle_disconnect(self, eio_sid: str, namespace: Optional[str], reason: Optional[str] = None) -> None:
"""Handle a client disconnect.""" """Handle a client disconnect."""
namespace = namespace or "/" namespace = namespace or "/"
sid = self.manager.sid_from_eio_sid(eio_sid, namespace) sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
@ -684,10 +699,12 @@ class AsyncServer(base_server.BaseServer):
) )
await self.manager.disconnect(sid, namespace, ignore_queue=True) await self.manager.disconnect(sid, namespace, ignore_queue=True)
async def _handle_event(self, eio_sid, namespace, id, data): async def _handle_event(self, eio_sid: str, namespace: Optional[str], id: Optional[int], data: Any) -> None:
"""Handle an incoming client event.""" """Handle an incoming client event."""
namespace = namespace or "/" namespace = namespace or "/"
sid = self.manager.sid_from_eio_sid(eio_sid, namespace) sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
if sid is None:
return
self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace)
if not self.manager.is_connected(sid, namespace): if not self.manager.is_connected(sid, namespace):
self.logger.warning("%s is not connected to namespace %s", sid, namespace) self.logger.warning("%s is not connected to namespace %s", sid, namespace)
@ -701,7 +718,15 @@ class AsyncServer(base_server.BaseServer):
else: else:
await self._handle_event_internal(self, sid, eio_sid, data, namespace, id) await self._handle_event_internal(self, sid, eio_sid, data, namespace, id)
async def _handle_event_internal(self, server, sid, eio_sid, data, namespace, id): async def _handle_event_internal(
self,
server: "AsyncServer",
sid: str,
eio_sid: str,
data: Any,
namespace: Optional[str],
id: Optional[int],
) -> None:
r = await server._trigger_event(data[0], namespace, sid, *data[1:]) r = await server._trigger_event(data[0], namespace, sid, *data[1:])
if r != self.not_handled and id is not None: if r != self.not_handled and id is not None:
# send ACK packet with the response returned by the handler # send ACK packet with the response returned by the handler
@ -717,14 +742,14 @@ class AsyncServer(base_server.BaseServer):
self.packet_class(packet.ACK, namespace=namespace, id=id, data=data), self.packet_class(packet.ACK, namespace=namespace, id=id, data=data),
) )
async def _handle_ack(self, eio_sid, namespace, id, data): async def _handle_ack(self, eio_sid: str, namespace: Optional[str], id: Optional[int], data: Any) -> None:
"""Handle ACK packets from the client.""" """Handle ACK packets from the client."""
namespace = namespace or "/" namespace = namespace or "/"
sid = self.manager.sid_from_eio_sid(eio_sid, namespace) sid = self.manager.sid_from_eio_sid(eio_sid, namespace)
self.logger.info("received ack from %s [%s]", sid, namespace) self.logger.info("received ack from %s [%s]", sid, namespace)
await self.manager.trigger_callback(sid, id, data) await self.manager.trigger_callback(sid, id, data)
async def _trigger_event(self, event, namespace, *args): async def _trigger_event(self, event: str, namespace: Optional[str], *args: Any) -> Any:
"""Invoke an application event handler.""" """Invoke an application event handler."""
# first see if we have an explicit handler for the event # first see if we have an explicit handler for the event
handler, args = self._get_event_handler(event, namespace, args) handler, args = self._get_event_handler(event, namespace, args)
@ -757,14 +782,14 @@ class AsyncServer(base_server.BaseServer):
return await handler.trigger_event(event, *args) return await handler.trigger_event(event, *args)
return self.not_handled return self.not_handled
async def _handle_eio_connect(self, eio_sid, environ): async def _handle_eio_connect(self, eio_sid: str, environ: Dict[str, Any]) -> None: # type: ignore[override]
"""Handle the Engine.IO connection event.""" """Handle the Engine.IO connection event."""
if not self.manager_initialized: if not self.manager_initialized:
self.manager_initialized = True self.manager_initialized = True
self.manager.initialize() self.manager.initialize()
self.environ[eio_sid] = environ self.environ[eio_sid] = environ
async def _handle_eio_message(self, eio_sid, data): async def _handle_eio_message(self, eio_sid: str, data: Union[bytes, str]) -> None: # type: ignore[override]
"""Dispatch Engine.IO messages.""" """Dispatch Engine.IO messages."""
if eio_sid in self._binary_packet: if eio_sid in self._binary_packet:
pkt = self._binary_packet[eio_sid] pkt = self._binary_packet[eio_sid]
@ -796,12 +821,12 @@ class AsyncServer(base_server.BaseServer):
else: else:
raise ValueError("Unknown packet type.") raise ValueError("Unknown packet type.")
async def _handle_eio_disconnect(self, eio_sid, reason): async def _handle_eio_disconnect(self, eio_sid: str, reason: str) -> None: # type: ignore[override]
"""Handle Engine.IO disconnect event.""" """Handle Engine.IO disconnect event."""
for n in list(self.manager.get_namespaces()).copy(): for n in list(self.manager.get_namespaces()).copy():
await self._handle_disconnect(eio_sid, n, reason) await self._handle_disconnect(eio_sid, n, reason)
if eio_sid in self.environ: if eio_sid in self.environ:
del self.environ[eio_sid] del self.environ[eio_sid]
def _engineio_server_class(self) -> engineio.AsyncServer: def _engineio_server_class(self) -> Any:
return engineio.AsyncServer return engineio.AsyncServer

23
src/socketio/base_namespace.py

@ -1,20 +1,23 @@
from typing import Any, Optional
class BaseNamespace: class BaseNamespace:
def __init__(self, namespace=None): def __init__(self, namespace: Optional[str] = None):
self.namespace = namespace or "/" self.namespace: str = namespace or "/"
def is_asyncio_based(self): def is_asyncio_based(self) -> bool:
return False return False
class BaseServerNamespace(BaseNamespace): class BaseServerNamespace(BaseNamespace):
def __init__(self, namespace=None): def __init__(self, namespace: Optional[str] = None):
super().__init__(namespace=namespace) super().__init__(namespace=namespace)
self.server = None self.server: Any = None
def _set_server(self, server): def _set_server(self, server: Any) -> None:
self.server = server self.server = server
def rooms(self, sid, namespace=None): def rooms(self, sid: str, namespace: Optional[str] = None):
"""Return the rooms a client is in. """Return the rooms a client is in.
The only difference with the :func:`socketio.Server.rooms` method is The only difference with the :func:`socketio.Server.rooms` method is
@ -25,9 +28,9 @@ class BaseServerNamespace(BaseNamespace):
class BaseClientNamespace(BaseNamespace): class BaseClientNamespace(BaseNamespace):
def __init__(self, namespace=None): def __init__(self, namespace: Optional[str] = None):
super().__init__(namespace=namespace) super().__init__(namespace=namespace)
self.client = None self.client: Any = None
def _set_client(self, client): def _set_client(self, client: Any) -> None:
self.client = client self.client = client

94
src/socketio/base_server.py

@ -1,4 +1,6 @@
import logging import logging
# pyright: reportMissingImports=false
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import engineio import engineio
@ -8,20 +10,20 @@ default_logger = logging.getLogger("socketio.server")
class BaseServer: class BaseServer:
reserved_events = ["connect", "disconnect"] reserved_events: List[str] = ["connect", "disconnect"]
reason = engineio.Server.reason reason = engineio.Server.reason # type: ignore[attr-defined]
def __init__( def __init__(
self, self,
client_manager=None, client_manager: Optional[manager.Manager] = None,
logger=False, logger: Union[bool, logging.Logger] = False,
serializer="default", serializer: Union[str, Any] = "default",
json=None, json: Optional[Any] = None,
async_handlers=True, async_handlers: bool = True,
always_connect=False, always_connect: bool = False,
namespaces=None, namespaces: Optional[Union[List[str], str]] = None,
**kwargs, **kwargs: Any,
): ) -> None:
engineio_options = kwargs engineio_options = kwargs
engineio_logger = engineio_options.pop("engineio_logger", None) engineio_logger = engineio_options.pop("engineio_logger", None)
if engineio_logger is not None: if engineio_logger is not None:
@ -35,20 +37,21 @@ class BaseServer:
else: else:
self.packet_class = serializer self.packet_class = serializer
if json is not None: if json is not None:
self.packet_class.json = json # packet_class is a class with a 'json' attribute at runtime
setattr(self.packet_class, "json", json)
engineio_options["json"] = json engineio_options["json"] = json
engineio_options["async_handlers"] = False engineio_options["async_handlers"] = False
self.eio = self._engineio_server_class()(**engineio_options) self.eio: Any = self._engineio_server_class()(**engineio_options)
self.eio.on("connect", self._handle_eio_connect) self.eio.on("connect", self._handle_eio_connect)
self.eio.on("message", self._handle_eio_message) self.eio.on("message", self._handle_eio_message)
self.eio.on("disconnect", self._handle_eio_disconnect) self.eio.on("disconnect", self._handle_eio_disconnect)
self.environ = {} self.environ: Dict[str, Dict[str, Any]] = {}
self.handlers = {} self.handlers: Dict[str, Dict[str, Callable[..., Any]]] = {}
self.namespace_handlers = {} self.namespace_handlers: Dict[str, base_namespace.BaseServerNamespace] = {}
self.not_handled = object() self.not_handled: object = object()
self._binary_packet = {} self._binary_packet: Dict[str, Any] = {}
if not isinstance(logger, bool): if not isinstance(logger, bool):
self.logger = logger self.logger = logger
@ -63,20 +66,25 @@ class BaseServer:
if client_manager is None: if client_manager is None:
client_manager = manager.Manager() client_manager = manager.Manager()
self.manager = client_manager self.manager: Any = client_manager
self.manager.set_server(self) self.manager.set_server(self)
self.manager_initialized = False self.manager_initialized = False
self.async_handlers = async_handlers self.async_handlers = async_handlers
self.always_connect = always_connect self.always_connect = always_connect
self.namespaces = namespaces or ["/"] self.namespaces: Union[List[str], str] = namespaces or ["/"]
self.async_mode = self.eio.async_mode self.async_mode = self.eio.async_mode
def is_asyncio_based(self): def is_asyncio_based(self) -> bool:
return False return False
def on(self, event, handler=None, namespace=None): def on(
self,
event: str,
handler: Optional[Callable[..., Any]] = None,
namespace: Optional[str] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Register an event handler. """Register an event handler.
:param event: The event name. It can be any string. The event names :param event: The event name. It can be any string. The event names
@ -129,7 +137,7 @@ class BaseServer:
""" """
namespace = namespace or "/" namespace = namespace or "/"
def set_handler(handler): def set_handler(handler: Callable[..., Any]) -> Callable[..., Any]:
if namespace not in self.handlers: if namespace not in self.handlers:
self.handlers[namespace] = {} self.handlers[namespace] = {}
self.handlers[namespace][event] = handler self.handlers[namespace][event] = handler
@ -138,8 +146,9 @@ class BaseServer:
if handler is None: if handler is None:
return set_handler return set_handler
set_handler(handler) set_handler(handler)
return set_handler
def event(self, *args, **kwargs): def event(self, *args: Any, **kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Decorator to register an event handler. """Decorator to register an event handler.
This is a simplified version of the ``on()`` method that takes the This is a simplified version of the ``on()`` method that takes the
@ -169,26 +178,31 @@ class BaseServer:
return self.on(args[0].__name__)(args[0]) return self.on(args[0].__name__)(args[0])
# the decorator was invoked with arguments # the decorator was invoked with arguments
def set_handler(handler): def set_handler(handler: Callable[..., Any]) -> Callable[..., Any]:
return self.on(handler.__name__, *args, **kwargs)(handler) return self.on(handler.__name__, *args, **kwargs)(handler)
return set_handler return set_handler
def register_namespace(self, namespace_handler): def register_namespace(self, namespace_handler: base_namespace.BaseServerNamespace) -> None:
"""Register a namespace handler object. """Register a namespace handler object.
:param namespace_handler: An instance of a :class:`Namespace` :param namespace_handler: An instance of a :class:`Namespace`
subclass that handles all the event traffic subclass that handles all the event traffic
for a namespace. for a namespace.
""" """
if not isinstance(namespace_handler, base_namespace.BaseServerNamespace): if not isinstance(namespace_handler, base_namespace.BaseServerNamespace): # type: ignore[redundant-expr]
raise ValueError("Not a namespace instance") raise ValueError("Not a namespace instance")
if self.is_asyncio_based() != namespace_handler.is_asyncio_based(): if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
raise ValueError("Not a valid namespace class for this server") raise ValueError("Not a valid namespace class for this server")
namespace_handler._set_server(self) namespace_handler._set_server(self)
self.namespace_handlers[namespace_handler.namespace] = namespace_handler self.namespace_handlers[namespace_handler.namespace] = namespace_handler
namespace_handler._set_server(self) # type: ignore[misc]
ns: str = str(namespace_handler.namespace or "/")
self.namespace_handlers[ns] = namespace_handler
def rooms(self, sid, namespace=None): def rooms(self, sid, namespace=None):
def rooms(self, sid: str, namespace: Optional[str] = None) -> List[str]:
"""Return the rooms a client is in. """Return the rooms a client is in.
:param sid: Session ID of the client. :param sid: Session ID of the client.
@ -198,7 +212,7 @@ class BaseServer:
namespace = namespace or "/" namespace = namespace or "/"
return self.manager.get_rooms(sid, namespace) return self.manager.get_rooms(sid, namespace)
def transport(self, sid, namespace=None): def transport(self, sid: str, namespace: Optional[str] = None) -> str:
"""Return the name of the transport used by the client. """Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'`` The two possible values returned by this function are ``'polling'``
@ -211,7 +225,7 @@ class BaseServer:
eio_sid = self.manager.eio_sid_from_sid(sid, namespace or "/") eio_sid = self.manager.eio_sid_from_sid(sid, namespace or "/")
return self.eio.transport(eio_sid) return self.eio.transport(eio_sid)
def get_environ(self, sid, namespace=None): def get_environ(self, sid: str, namespace: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Return the WSGI environ dictionary for a client. """Return the WSGI environ dictionary for a client.
:param sid: The session of the client. :param sid: The session of the client.
@ -219,9 +233,13 @@ class BaseServer:
the default namespace is used. the default namespace is used.
""" """
eio_sid = self.manager.eio_sid_from_sid(sid, namespace or "/") eio_sid = self.manager.eio_sid_from_sid(sid, namespace or "/")
if eio_sid is None:
return None
return self.environ.get(eio_sid) return self.environ.get(eio_sid)
def _get_event_handler(self, event, namespace, args): def _get_event_handler(
self, event: str, namespace: Optional[str], args: Tuple[Any, ...]
) -> Tuple[Optional[Callable[..., Any]], Tuple[Any, ...]]:
# Return the appropriate application event handler # Return the appropriate application event handler
# #
# Resolution priority: # Resolution priority:
@ -229,7 +247,7 @@ class BaseServer:
# - self.handlers[namespace]["*"] # - self.handlers[namespace]["*"]
# - self.handlers["*"][event] # - self.handlers["*"][event]
# - self.handlers["*"]["*"] # - self.handlers["*"]["*"]
handler = None handler: Optional[Callable[..., Any]] = None
if namespace in self.handlers: if namespace in self.handlers:
if event in self.handlers[namespace]: if event in self.handlers[namespace]:
handler = self.handlers[namespace][event] handler = self.handlers[namespace][event]
@ -245,13 +263,15 @@ class BaseServer:
args = (event, namespace, *args) args = (event, namespace, *args)
return handler, args return handler, args
def _get_namespace_handler(self, namespace, args): def _get_namespace_handler(
self, namespace: Optional[str], args: Tuple[Any, ...]
) -> Tuple[Optional[base_namespace.BaseServerNamespace], Tuple[Any, ...]]:
# Return the appropriate application event handler. # Return the appropriate application event handler.
# #
# Resolution priority: # Resolution priority:
# - self.namespace_handlers[namespace] # - self.namespace_handlers[namespace]
# - self.namespace_handlers["*"] # - self.namespace_handlers["*"]
handler = None handler: Optional[base_namespace.BaseServerNamespace] = None
if namespace in self.namespace_handlers: if namespace in self.namespace_handlers:
handler = self.namespace_handlers[namespace] handler = self.namespace_handlers[namespace]
if handler is None and "*" in self.namespace_handlers: if handler is None and "*" in self.namespace_handlers:
@ -259,14 +279,14 @@ class BaseServer:
args = (namespace, *args) args = (namespace, *args)
return handler, args return handler, args
def _handle_eio_connect(self): # pragma: no cover def _handle_eio_connect(self) -> None: # pragma: no cover
raise NotImplementedError raise NotImplementedError
def _handle_eio_message(self, data): # pragma: no cover def _handle_eio_message(self, data: Any) -> None: # pragma: no cover
raise NotImplementedError raise NotImplementedError
def _handle_eio_disconnect(self): # pragma: no cover def _handle_eio_disconnect(self) -> None: # pragma: no cover
raise NotImplementedError raise NotImplementedError
def _engineio_server_class(self): # pragma: no cover def _engineio_server_class(self) -> Any: # pragma: no cover
raise NotImplementedError("Must be implemented in subclasses") raise NotImplementedError("Must be implemented in subclasses")

13
src/socketio/tornado.py

@ -1,10 +1,17 @@
from typing import Any
# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false
try: try:
from engineio.async_drivers.tornado import ( from engineio.async_drivers.tornado import (
get_tornado_handler as get_engineio_handler, get_tornado_handler as get_engineio_handler,
) )
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
get_engineio_handler = None get_engineio_handler = None # type: ignore[assignment]
def get_tornado_handler(socketio_server): # pragma: no cover def get_tornado_handler(socketio_server: Any) -> Any: # pragma: no cover
return get_engineio_handler(socketio_server.eio) # engineio handler factory expects an Engine.IO server instance
if get_engineio_handler is None:
raise RuntimeError("Tornado async driver is not available")
return get_engineio_handler(socketio_server.eio) # type: ignore[operator]

Loading…
Cancel
Save