diff --git a/fastapi/__init__.py b/fastapi/__init__.py index b02bf8b4f..3c2968660 100644 --- a/fastapi/__init__.py +++ b/fastapi/__init__.py @@ -21,5 +21,6 @@ from .param_functions import Security as Security from .requests import Request as Request from .responses import Response as Response from .routing import APIRouter as APIRouter +from .types import TypedState as TypedState from .websockets import WebSocket as WebSocket from .websockets import WebSocketDisconnect as WebSocketDisconnect diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..0dbb16e1b 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -29,6 +29,7 @@ class Dependant: background_tasks_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None security_scopes: Optional[List[str]] = None + state_param_name: Optional[str] = None use_cache: bool = True path: Optional[str] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..07a7de922 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -56,6 +56,7 @@ from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.types import RequestState, TypedState from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel from pydantic.fields import FieldInfo @@ -335,6 +336,9 @@ def add_non_field_param_to_dependency( elif lenient_issubclass(type_annotation, SecurityScopes): dependant.security_scopes_param_name = param_name return True + elif lenient_issubclass(type_annotation, RequestState): + dependant.state_param_name = param_name + return True return None @@ -360,7 +364,10 @@ def analyze_param( use_annotation = annotation type_annotation = annotation # Extract Annotated info - if get_origin(use_annotation) is Annotated: + origin = get_origin(annotation) + if origin is TypedState: + type_annotation = RequestState + if origin is Annotated: annotated_args = get_args(annotation) type_annotation = annotated_args[0] fastapi_annotations = [ @@ -436,6 +443,7 @@ def analyze_param( Response, StarletteBackgroundTasks, SecurityScopes, + RequestState, ), ): assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" @@ -686,6 +694,8 @@ async def solve_dependencies( values[dependant.security_scopes_param_name] = SecurityScopes( scopes=dependant.security_scopes ) + if dependant.state_param_name: + values[dependant.state_param_name] = TypedState(_state=request.state._state) return SolvedDependency( values=values, errors=errors, diff --git a/fastapi/types.py b/fastapi/types.py index 3205654c7..c43514dda 100644 --- a/fastapi/types.py +++ b/fastapi/types.py @@ -1,6 +1,6 @@ import types from enum import Enum -from typing import Any, Callable, Dict, Set, Type, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Set, Type, TypeVar, Union from pydantic import BaseModel @@ -8,3 +8,32 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) UnionType = getattr(types, "UnionType", Union) ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str] IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] + + +StateType = TypeVar("StateType", bound=Dict[str, Any]) + + +class RequestState: + pass + + +class TypedState(RequestState, Generic[StateType]): + def __init__(self, _state: StateType) -> None: + super().__init__() + self._state = _state + + def __getattr__(self, item: str) -> Any: + if item.startswith("_"): + # TODO: Restrict overriding of the _state attribute + return object.__getattribute__(self, item) + if item in self._state: + return self._state[item] + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + + def __setattr__(self, key: str, value: Any) -> None: + if key.startswith("_"): + super().__setattr__(key, value) + else: + self._state[key] = value