Browse Source

Merge 9365d9333f into 6df50d40fe

pull/13965/merge
Motov Yurii 5 days ago
committed by GitHub
parent
commit
1d2c84cd35
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 1
      fastapi/__init__.py
  2. 1
      fastapi/dependencies/models.py
  3. 12
      fastapi/dependencies/utils.py
  4. 31
      fastapi/types.py
  5. 167
      tests/test_typed_state.py

1
fastapi/__init__.py

@ -21,5 +21,6 @@ from .param_functions import Security as Security
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter
from .types import TypedState as TypedState
from .websockets import WebSocket as WebSocket
from .websockets import WebSocketDisconnect as WebSocketDisconnect

1
fastapi/dependencies/models.py

@ -29,6 +29,7 @@ class Dependant:
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None
state_param_name: Optional[str] = None
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)

12
fastapi/dependencies/utils.py

@ -56,6 +56,7 @@ from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.types import RequestState, TypedState
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
from pydantic.fields import FieldInfo
@ -335,6 +336,9 @@ def add_non_field_param_to_dependency(
elif lenient_issubclass(type_annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name
return True
elif lenient_issubclass(type_annotation, RequestState):
dependant.state_param_name = param_name
return True
return None
@ -360,7 +364,10 @@ def analyze_param(
use_annotation = annotation
type_annotation = annotation
# Extract Annotated info
if get_origin(use_annotation) is Annotated:
origin = get_origin(annotation)
if origin is TypedState:
type_annotation = RequestState
if origin is Annotated:
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
fastapi_annotations = [
@ -436,6 +443,7 @@ def analyze_param(
Response,
StarletteBackgroundTasks,
SecurityScopes,
RequestState,
),
):
assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
@ -686,6 +694,8 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
if dependant.state_param_name:
values[dependant.state_param_name] = TypedState(_state=request.state._state)
return SolvedDependency(
values=values,
errors=errors,

31
fastapi/types.py

@ -1,6 +1,6 @@
import types
from enum import Enum
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
from typing import Any, Callable, Dict, Generic, Set, Type, TypeVar, Union
from pydantic import BaseModel
@ -8,3 +8,32 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
StateType = TypeVar("StateType", bound=Dict[str, Any])
class RequestState:
pass
class TypedState(RequestState, Generic[StateType]):
def __init__(self, _state: StateType) -> None:
super().__init__()
self._state = _state
def __getattr__(self, item: str) -> Any:
if item.startswith("_"):
# TODO: Restrict overriding of the _state attribute
return object.__getattribute__(self, item)
if item in self._state:
return self._state[item]
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}'"
)
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith("_"):
super().__setattr__(key, value)
else:
self._state[key] = value

167
tests/test_typed_state.py

@ -0,0 +1,167 @@
from contextlib import asynccontextmanager
from typing import TypedDict
import pytest
from fastapi import Depends, FastAPI, TypedState
from fastapi.testclient import TestClient
class MyState(TypedDict):
param_1: str
param_2: int
@asynccontextmanager
async def lifespan(app: FastAPI):
state = MyState(param_1="example", param_2=42)
yield state
app = FastAPI(lifespan=lifespan)
@app.get("/read")
async def read_state(state: TypedState[MyState]):
return {"param_1": state._state["param_1"], "param_2": state._state["param_2"]}
async def update_state(state: TypedState[MyState]):
state._state["param_1"] = "Updated"
@app.get("/updated-state", dependencies=[Depends(update_state)])
async def read_updated_state(state: TypedState[MyState]):
return {"param_1": state._state["param_1"], "param_2": state._state["param_2"]}
@app.get("/read-attribute-access")
async def read_attribute(state: TypedState[MyState]):
# This way it's not typed, but attribute access works
return {
"param_1": state.param_1,
"param_2": state.param_2,
}
async def update_state_attribute_access(state: TypedState[MyState]):
state.param_1 = "Updated" # This way it's not typed, but attribute access works
@app.get(
"/updated-state-attribute-access",
dependencies=[Depends(update_state_attribute_access)],
)
async def read_updated_attribute(state: TypedState[MyState]):
# This way it's not typed, but attribute access works
return {
"param_1": state.param_1,
"param_2": state.param_2,
}
@pytest.mark.parametrize(
"path",
[
"/read",
"/read-attribute-access",
],
)
def test_read(path: str):
with TestClient(app) as client:
response = client.get(path)
assert response.status_code == 200
assert response.json() == {"param_1": "example", "param_2": 42}
@pytest.mark.parametrize(
"path",
[
"/updated-state",
"/updated-state-attribute-access",
],
)
def test_read_updated_state_state(path: str):
with TestClient(app) as client:
response = client.get("/updated-state")
assert response.status_code == 200
assert response.json() == {"param_1": "Updated", "param_2": 42}
def test_openapi_schema():
with TestClient(app) as client:
response = client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
assert schema == {
"info": {
"title": "FastAPI",
"version": "0.1.0",
},
"openapi": "3.1.0",
"paths": {
"/read": {
"get": {
"operationId": "read_state_read_get",
"responses": {
"200": {
"content": {
"application/json": {
"schema": {},
},
},
"description": "Successful Response",
},
},
"summary": "Read State",
},
},
"/read-attribute-access": {
"get": {
"operationId": "read_attribute_read_attribute_access_get",
"responses": {
"200": {
"content": {
"application/json": {
"schema": {},
},
},
"description": "Successful Response",
},
},
"summary": "Read Attribute",
},
},
"/updated-state": {
"get": {
"operationId": "read_updated_state_updated_state_get",
"responses": {
"200": {
"content": {
"application/json": {
"schema": {},
},
},
"description": "Successful Response",
},
},
"summary": "Read Updated State",
},
},
"/updated-state-attribute-access": {
"get": {
"operationId": "read_updated_attribute_updated_state_attribute_access_get",
"responses": {
"200": {
"content": {
"application/json": {
"schema": {},
},
},
"description": "Successful Response",
},
},
"summary": "Read Updated Attribute",
},
},
},
}
Loading…
Cancel
Save