Browse Source

feat: DI, pydantic models input validations

pull/1486/head
Konstantin Ponomarev 1 week ago
parent
commit
f209ef4b27
  1. 1516
      poetry.lock
  2. 3
      pyproject.toml
  3. 6
      src/fastsio/__init__.py
  4. 126
      src/fastsio/async_server.py
  5. 12
      src/fastsio/types.py

1516
poetry.lock

File diff suppressed because it is too large

3
pyproject.toml

@ -33,13 +33,14 @@ docs = [
] ]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8" python = ">=3.9"
bidict = ">=0.21.0" bidict = ">=0.21.0"
python-engineio = "^4.12.2" python-engineio = "^4.12.2"
requests = { version = ">=2.21.0", optional = true } requests = { version = ">=2.21.0", optional = true }
websocket-client = { version = ">=0.54.0", optional = true } websocket-client = { version = ">=0.54.0", optional = true }
aiohttp = { version = ">=3.4", optional = true } aiohttp = { version = ">=3.4", optional = true }
sphinx = { version = "*", optional = true } sphinx = { version = "*", optional = true }
pydantic = "^2.11.7"
[tool.poetry.extras] [tool.poetry.extras]
client = ["requests", "websocket-client"] client = ["requests", "websocket-client"]

6
src/fastsio/__init__.py

@ -18,6 +18,7 @@ from .server import Server
from .simple_client import SimpleClient from .simple_client import SimpleClient
from .tornado import get_tornado_handler from .tornado import get_tornado_handler
from .router import RouterSIO from .router import RouterSIO
from .types import SocketID, Environ, Auth, Reason, Data
from .zmq_manager import ZmqManager from .zmq_manager import ZmqManager
__all__ = [ __all__ = [
@ -45,4 +46,9 @@ __all__ = [
"ZmqManager", "ZmqManager",
"get_tornado_handler", "get_tornado_handler",
"RouterSIO", "RouterSIO",
"SocketID",
"Environ",
"Auth",
"Reason",
"Data",
] ]

126
src/fastsio/async_server.py

@ -1,10 +1,14 @@
import asyncio import asyncio
import inspect
# pyright: reportMissingImports=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false # pyright: reportMissingImports=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false
from typing import Any, AsyncContextManager, Callable, Dict, List, Optional, Set, Union, TYPE_CHECKING, Coroutine from typing import Any, AsyncContextManager, Callable, Dict, List, Optional, Set, Union, TYPE_CHECKING, Coroutine
import engineio import engineio
from . import async_manager, base_server, exceptions, packet from . import async_manager, base_server, exceptions, packet
from .types import SocketID, Environ, Auth, Reason, Data
from pydantic import BaseModel as _PydanticBaseModel
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from .async_admin import InstrumentedAsyncServer from .async_admin import InstrumentedAsyncServer
@ -17,8 +21,8 @@ task_reference_holder: Set[asyncio.Task[Any]] = set()
class AsyncServer(base_server.BaseServer): class AsyncServer(base_server.BaseServer):
# Attribute type hints to aid static type checkers # Attribute type hints to aid static type checkers
manager: async_manager.AsyncManager manager: Any
eio: Any eio: engineio.AsyncServer
packet_class: Any packet_class: Any
handlers: Dict[str, Dict[str, Any]] handlers: Dict[str, Dict[str, Any]]
namespace_handlers: Dict[str, Any] namespace_handlers: Dict[str, Any]
@ -751,28 +755,138 @@ class AsyncServer(base_server.BaseServer):
async def _trigger_event(self, event: str, namespace: Optional[str], *args: Any) -> Any: async def _trigger_event(self, event: str, namespace: Optional[str], *args: Any) -> Any:
"""Invoke an application event handler.""" """Invoke an application event handler."""
# Keep originals to support dependency injection from raw payload
original_args = args
original_sid: Optional[str] = None
if original_args and isinstance(original_args[0], str):
original_sid = original_args[0]
# first see if we have an explicit handler for the event # first see if we have an explicit handler for the event
handler, args = self._get_event_handler(event, namespace, args) handler, args = self._get_event_handler(event, namespace, args)
if handler: if handler:
# Build DI kwargs for supported injections, without breaking
# positional compatibility.
di_kwargs: Dict[str, Any] = {}
try:
sig = inspect.signature(handler) # type: ignore[arg-type]
param_items = list(sig.parameters.items())
except (TypeError, ValueError): # builtins/callables without signature
param_items = []
# Determine which params are already fulfilled positionally
positionally_fulfilled_param_names = []
for idx, (pname, p) in enumerate(param_items):
if idx < len(args) and p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
positionally_fulfilled_param_names.append(pname)
# Prepare payload for Pydantic, only for non-reserved events
payload_for_model: Any = None
if event not in base_server.BaseServer.reserved_events:
if len(original_args) >= 2:
candidate_payload = original_args[1]
# Only accept single-argument payload for model injection
if len(original_args[1:]) == 1:
payload_for_model = candidate_payload
# Prepare environ/auth for DI
computed_environ: Any = None
if original_sid is not None:
try:
computed_environ = self.get_environ(original_sid, namespace)
except Exception:
computed_environ = None
connect_auth_payload: Any = None
if event == "connect" and len(original_args) >= 3:
connect_auth_payload = original_args[2]
for pname, p in param_items:
ann = p.annotation
# Skip if already provided positionally
# if pname in positionally_fulfilled_param_names:
# continue
# Inject AsyncServer by annotation
try:
from .async_server import AsyncServer as _AsyncServerType # local import to avoid cycles
except Exception: # pragma: no cover
_AsyncServerType = None # type: ignore
if _AsyncServerType is not None and ann is _AsyncServerType:
di_kwargs.setdefault(pname, self)
continue
# Inject SocketID by annotation
if ann is SocketID and original_sid is not None:
di_kwargs.setdefault(pname, original_sid)
continue
# Inject Environ by annotation
if ann is Environ:
di_kwargs.setdefault(pname, computed_environ or {})
continue
# Inject Auth by annotation (connect event carries auth payload)
if ann is Auth:
if event == "connect":
di_kwargs.setdefault(pname, connect_auth_payload or None)
else:
raise ValueError("You can`t use `Auth` not in connect handler")
continue
if ann is Reason:
if event == "disconnect":
di_kwargs.setdefault(pname, args[-1])
else:
raise ValueError("You can`t use `Reason` not in disconnect handler")
continue
if ann is Data:
di_kwargs.setdefault(pname, args[-1])
try:
is_model = isinstance(ann, type) and issubclass(ann, _PydanticBaseModel) # type: ignore[arg-type]
except Exception:
is_model = False
if is_model:
if payload_for_model is None:
raise ValueError(
f"Cannot inject Pydantic model '{ann.__name__}': expected a single payload argument"
)
try:
# Pydantic v2: model_validate
if hasattr(ann, "model_validate"):
di_kwargs.setdefault(pname, ann.model_validate(payload_for_model)) # type: ignore[attr-defined]
else: # Pydantic v1 fallback
di_kwargs.setdefault(pname, ann.parse_obj(payload_for_model)) # type: ignore[attr-defined]
except Exception as exc: # pragma: no cover - validation error path
raise ValueError(
f"Failed to validate payload for '{ann.__name__}': {exc}"
) from exc
if asyncio.iscoroutinefunction(handler): if asyncio.iscoroutinefunction(handler):
try: try:
try: try:
ret = await handler(*args) ret = await handler(**di_kwargs)
except TypeError: except TypeError:
# legacy disconnect events use only one argument # legacy disconnect events use only one argument
if event == "disconnect": if event == "disconnect":
ret = await handler(*args[:-1]) ret = await handler(**di_kwargs)
else: # pragma: no cover else: # pragma: no cover
raise raise
except asyncio.CancelledError: # pragma: no cover except asyncio.CancelledError: # pragma: no cover
ret = None ret = None
else: else:
try: try:
ret = handler(*args) ret = handler(**di_kwargs)
except TypeError: except TypeError:
# legacy disconnect events use only one argument # legacy disconnect events use only one argument
if event == "disconnect": if event == "disconnect":
ret = handler(*args[:-1]) ret = handler(**di_kwargs)
else: # pragma: no cover else: # pragma: no cover
raise raise
return ret return ret

12
src/fastsio/types.py

@ -0,0 +1,12 @@
from typing import NewType
# Public typing alias for better readability in handler annotations
SocketID = NewType("SocketID", str)
Environ = NewType("Environ", dict)
Auth = NewType("Auth", dict)
Data = NewType("Data", dict)
Reason = NewType("Reason", str)
__all__ = ["SocketID", "Environ", "Auth", "Reason", "Data"]
Loading…
Cancel
Save