Browse Source

Inject typed request state instance

pull/13965/head
Yurii Motov 5 days ago
parent
commit
6ac02358f0
  1. 1
      fastapi/__init__.py
  2. 1
      fastapi/dependencies/models.py
  3. 12
      fastapi/dependencies/utils.py
  4. 31
      fastapi/types.py

1
fastapi/__init__.py

@ -21,5 +21,6 @@ from .param_functions import Security as Security
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter
from .types import TypedState as TypedState
from .websockets import WebSocket as WebSocket
from .websockets import WebSocketDisconnect as WebSocketDisconnect

1
fastapi/dependencies/models.py

@ -29,6 +29,7 @@ class Dependant:
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None
state_param_name: Optional[str] = None
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)

12
fastapi/dependencies/utils.py

@ -56,6 +56,7 @@ from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.types import RequestState, TypedState
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
from pydantic.fields import FieldInfo
@ -335,6 +336,9 @@ def add_non_field_param_to_dependency(
elif lenient_issubclass(type_annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name
return True
elif lenient_issubclass(type_annotation, RequestState):
dependant.state_param_name = param_name
return True
return None
@ -360,7 +364,10 @@ def analyze_param(
use_annotation = annotation
type_annotation = annotation
# Extract Annotated info
if get_origin(use_annotation) is Annotated:
origin = get_origin(annotation)
if origin is TypedState:
type_annotation = RequestState
if origin is Annotated:
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
fastapi_annotations = [
@ -436,6 +443,7 @@ def analyze_param(
Response,
StarletteBackgroundTasks,
SecurityScopes,
RequestState,
),
):
assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
@ -686,6 +694,8 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
if dependant.state_param_name:
values[dependant.state_param_name] = TypedState(_state=request.state._state)
return SolvedDependency(
values=values,
errors=errors,

31
fastapi/types.py

@ -1,6 +1,6 @@
import types
from enum import Enum
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
from typing import Any, Callable, Dict, Generic, Set, Type, TypeVar, Union
from pydantic import BaseModel
@ -8,3 +8,32 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
StateType = TypeVar("StateType", bound=Dict[str, Any])
class RequestState:
pass
class TypedState(RequestState, Generic[StateType]):
def __init__(self, _state: StateType) -> None:
super().__init__()
self._state = _state
def __getattr__(self, item: str) -> Any:
if item.startswith("_"):
# TODO: Restrict overriding of the _state attribute
return object.__getattribute__(self, item)
if item in self._state:
return self._state[item]
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}'"
)
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith("_"):
super().__setattr__(key, value)
else:
self._state[key] = value

Loading…
Cancel
Save