Browse Source

Added support for lifespan-scoped dependencies using a new dependency_scope argument.

pull/12529/head
Nir Schulman 9 months ago
parent
commit
25407d039a
  1. 56
      fastapi/applications.py
  2. 46
      fastapi/dependencies/models.py
  3. 247
      fastapi/dependencies/utils.py
  4. 46
      fastapi/lifespan.py
  5. 6
      fastapi/openapi/utils.py
  6. 32
      fastapi/param_functions.py
  7. 23
      fastapi/params.py
  8. 55
      fastapi/routing.py
  9. 703
      tests/test_lifespan_scoped_dependencies.py
  10. 21
      tests/test_params_repr.py
  11. 16
      tests/test_router_events.py

56
fastapi/applications.py

@ -1,6 +1,8 @@
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any, Any,
AsyncGenerator,
Awaitable, Awaitable,
Callable, Callable,
Coroutine, Coroutine,
@ -15,12 +17,14 @@ from typing import (
from fastapi import routing from fastapi import routing
from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.utils import is_coroutine_callable
from fastapi.exception_handlers import ( from fastapi.exception_handlers import (
http_exception_handler, http_exception_handler,
request_validation_exception_handler, request_validation_exception_handler,
websocket_request_validation_exception_handler, websocket_request_validation_exception_handler,
) )
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.lifespan import resolve_lifespan_dependants
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.openapi.docs import ( from fastapi.openapi.docs import (
get_redoc_html, get_redoc_html,
@ -29,9 +33,11 @@ from fastapi.openapi.docs import (
) )
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends from fastapi.params import Depends
from fastapi.routing import merge_lifespan_context
from fastapi.types import DecoratedCallable, IncEx from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import generate_unique_id from fastapi.utils import generate_unique_id
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import State from starlette.datastructures import State
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.middleware import Middleware from starlette.middleware import Middleware
@ -929,12 +935,24 @@ class FastAPI(Starlette):
""" """
), ),
] = {} ] = {}
if lifespan is None:
lifespan = FastAPI._internal_lifespan
else:
lifespan = merge_lifespan_context(
FastAPI._internal_lifespan,
lifespan
)
# Since we always use a lifespan, starlette will no longer run event
# handlers which are defined in the scope of the application.
# We therefore need to call them ourselves.
self._on_startup = on_startup or []
self._on_shutdown = on_shutdown or []
self.router: routing.APIRouter = routing.APIRouter( self.router: routing.APIRouter = routing.APIRouter(
routes=routes, routes=routes,
redirect_slashes=redirect_slashes, redirect_slashes=redirect_slashes,
dependency_overrides_provider=self, dependency_overrides_provider=self,
on_startup=on_startup,
on_shutdown=on_shutdown,
lifespan=lifespan, lifespan=lifespan,
default_response_class=default_response_class, default_response_class=default_response_class,
dependencies=dependencies, dependencies=dependencies,
@ -963,6 +981,32 @@ class FastAPI(Starlette):
self.middleware_stack: Union[ASGIApp, None] = None self.middleware_stack: Union[ASGIApp, None] = None
self.setup() self.setup()
@asynccontextmanager
async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]:
async with AsyncExitStack() as exit_stack:
lifespan_scoped_dependencies = await resolve_lifespan_dependants(
app=self,
async_exit_stack=exit_stack
)
try:
for handler in self._on_startup:
if is_coroutine_callable(handler):
await handler()
else:
await run_in_threadpool(handler)
yield {
"__fastapi__": {
"lifespan_scoped_dependencies": lifespan_scoped_dependencies
}
}
finally:
for handler in self._on_shutdown:
if is_coroutine_callable(handler):
await handler()
else:
await run_in_threadpool(handler)
def openapi(self) -> Dict[str, Any]: def openapi(self) -> Dict[str, Any]:
""" """
Generate the OpenAPI schema of the application. This is called by FastAPI Generate the OpenAPI schema of the application. This is called by FastAPI
@ -4492,7 +4536,13 @@ class FastAPI(Starlette):
Read more about it in the Read more about it in the
[FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated). [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated).
""" """
return self.router.on_event(event_type) def decorator(func: DecoratedCallable) -> DecoratedCallable:
if event_type == "startup":
self._on_startup.append(func)
else:
self._on_shutdown.append(func)
return func
return decorator
def middleware( def middleware(
self, self,

46
fastapi/dependencies/models.py

@ -1,8 +1,9 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from fastapi._compat import ModelField from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from typing_extensions import TypeAlias
@dataclass @dataclass
@ -11,17 +12,41 @@ class SecurityRequirement:
scopes: Optional[Sequence[str]] = None scopes: Optional[Sequence[str]] = None
LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]]
@dataclass
class LifespanDependant:
caller: Callable[..., Any]
dependencies: List["LifespanDependant"] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
cache_key: LifespanDependantCacheKey = field(init=False)
def __post_init__(self) -> None:
if self.use_cache:
self.cache_key = self.call
else:
self.cache_key = (self.caller, self.name)
EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
@dataclass @dataclass
class Dependant: class EndpointDependant:
endpoint_dependencies: List["EndpointDependant"] = field(default_factory=list)
lifespan_dependencies: List[LifespanDependant] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(
init=False)
path_params: List[ModelField] = field(default_factory=list) path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list)
cookie_params: List[ModelField] = field(default_factory=list) cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
security_requirements: List[SecurityRequirement] = field(default_factory=list) security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None request_param_name: Optional[str] = None
websocket_param_name: Optional[str] = None websocket_param_name: Optional[str] = None
http_connection_param_name: Optional[str] = None http_connection_param_name: Optional[str] = None
@ -29,9 +54,16 @@ class Dependant:
background_tasks_param_name: Optional[str] = None background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None security_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
# Kept for backwards compatibility
@property
def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]:
return tuple(self.endpoint_dependencies + self.lifespan_dependencies)
# Kept for backwards compatibility
Dependant = EndpointDependant
CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey]

247
fastapi/dependencies/utils.py

@ -51,7 +51,14 @@ from fastapi.concurrency import (
asynccontextmanager, asynccontextmanager,
contextmanager_in_threadpool, contextmanager_in_threadpool,
) )
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import (
CacheKey,
EndpointDependant,
LifespanDependant,
LifespanDependantCacheKey,
SecurityRequirement,
)
from fastapi.exceptions import FastAPIError
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -112,8 +119,9 @@ def get_param_sub_dependant(
param_name: str, param_name: str,
depends: params.Depends, depends: params.Depends,
path: str, path: str,
caller: Callable[..., Any],
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
) -> Dependant: ) -> Union[EndpointDependant, LifespanDependant]:
assert depends.dependency assert depends.dependency
return get_sub_dependant( return get_sub_dependant(
depends=depends, depends=depends,
@ -121,14 +129,25 @@ def get_param_sub_dependant(
path=path, path=path,
name=param_name, name=param_name,
security_scopes=security_scopes, security_scopes=security_scopes,
caller=caller,
) )
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: def get_parameterless_sub_dependant(
*,
depends: params.Depends,
path: str,
caller: Callable[..., Any]
) -> Union[EndpointDependant, LifespanDependant]:
assert callable( assert callable(
depends.dependency depends.dependency
), "A parameter-less dependency must have a callable dependency" ), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) return get_sub_dependant(
depends=depends,
dependency=depends.dependency,
path=path,
caller=caller
)
def get_sub_dependant( def get_sub_dependant(
@ -136,9 +155,18 @@ def get_sub_dependant(
depends: params.Depends, depends: params.Depends,
dependency: Callable[..., Any], dependency: Callable[..., Any],
path: str, path: str,
caller: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
) -> Dependant: ) -> Union[EndpointDependant, LifespanDependant]:
if depends.dependency_scope == "lifespan":
return get_lifespan_dependant(
caller=caller,
call=depends.dependency,
name=name,
use_cache=depends.use_cache
)
elif depends.dependency_scope == "endpoint":
security_requirement = None security_requirement = None
security_scopes = security_scopes or [] security_scopes = security_scopes or []
if isinstance(depends, params.Security): if isinstance(depends, params.Security):
@ -151,7 +179,7 @@ def get_sub_dependant(
security_requirement = SecurityRequirement( security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes security_scheme=dependency, scopes=use_scopes
) )
sub_dependant = get_dependant( sub_dependant = get_endpoint_dependant(
path=path, path=path,
call=dependency, call=dependency,
name=name, name=name,
@ -161,32 +189,35 @@ def get_sub_dependant(
if security_requirement: if security_requirement:
sub_dependant.security_requirements.append(security_requirement) sub_dependant.security_requirements.append(security_requirement)
return sub_dependant return sub_dependant
else:
raise ValueError(
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] f"Dependency {name} of {caller} has an invalid "
f"sub-dependency scope: {depends.dependency_scope}"
)
def get_flat_dependant( def get_flat_dependant(
dependant: Dependant, dependant: EndpointDependant,
*, *,
skip_repeats: bool = False, skip_repeats: bool = False,
visited: Optional[List[CacheKey]] = None, visited: Optional[List[CacheKey]] = None,
) -> Dependant: ) -> EndpointDependant:
if visited is None: if visited is None:
visited = [] visited = []
visited.append(dependant.cache_key) visited.append(dependant.cache_key)
flat_dependant = Dependant( flat_dependant = EndpointDependant(
path_params=dependant.path_params.copy(), path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(), query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(), header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(), cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(), body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(), security_requirements=dependant.security_requirements.copy(),
lifespan_dependencies=dependant.lifespan_dependencies.copy(),
use_cache=dependant.use_cache, use_cache=dependant.use_cache,
path=dependant.path, path=dependant.path
) )
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.endpoint_dependencies:
if skip_repeats and sub_dependant.cache_key in visited: if skip_repeats and sub_dependant.cache_key in visited:
continue continue
flat_sub = get_flat_dependant( flat_sub = get_flat_dependant(
@ -198,6 +229,7 @@ def get_flat_dependant(
flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params) flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements) flat_dependant.security_requirements.extend(flat_sub.security_requirements)
flat_dependant.lifespan_dependencies.extend(flat_sub.lifespan_dependencies)
return flat_dependant return flat_dependant
@ -211,7 +243,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
return fields return fields
def get_flat_params(dependant: Dependant) -> List[ModelField]: def get_flat_params(dependant: EndpointDependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True) flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params) path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params) query_params = _get_flat_fields_from_params(flat_dependant.query_params)
@ -254,18 +286,64 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns) return get_typed_annotation(annotation, globalns)
def get_dependant( def get_lifespan_dependant(
*,
caller: Callable[..., Any],
call: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True,
) -> LifespanDependant:
dependency_signature = get_typed_signature(call)
signature_params = dependency_signature.parameters
dependant = LifespanDependant(
call=call,
name=name,
use_cache=use_cache,
caller=caller
)
for param_name, param in signature_params.items():
param_details = analyze_param(
param_name=param_name,
annotation=param.annotation,
value=param.default,
is_path_param=False,
)
if param_details.depends is None:
raise FastAPIError(
f"Lifespan dependency {dependant.name} was defined with an "
f"invalid argument {param_name}. Lifespan dependencies may "
f"only use other lifespan dependencies as arguments.")
if param_details.depends.dependency_scope != "lifespan":
raise FastAPIError(
"Lifespan dependency may not use "
"sub-dependencies of other scopes."
)
sub_dependant = get_lifespan_dependant(
name=param_name,
call=param_details.depends.dependency,
use_cache=param_details.depends.use_cache,
caller=call
)
dependant.dependencies.append(sub_dependant)
return dependant
def get_endpoint_dependant(
*, *,
path: str, path: str,
call: Callable[..., Any], call: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
) -> Dependant: ) -> EndpointDependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters signature_params = endpoint_signature.parameters
dependant = Dependant( dependant = EndpointDependant(
call=call, call=call,
name=name, name=name,
path=path, path=path,
@ -281,13 +359,28 @@ def get_dependant(
is_path_param=is_path_param, is_path_param=is_path_param,
) )
if param_details.depends is not None: if param_details.depends is not None:
if param_details.depends.dependency_scope == "endpoint":
sub_dependant = get_param_sub_dependant( sub_dependant = get_param_sub_dependant(
param_name=param_name, param_name=param_name,
depends=param_details.depends, depends=param_details.depends,
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
caller=call,
)
dependant.endpoint_dependencies.append(sub_dependant)
elif param_details.depends.dependency_scope == "lifespan":
sub_dependant = get_lifespan_dependant(
caller=call,
call=param_details.depends.dependency,
name=param_name,
use_cache=param_details.depends.use_cache,
)
dependant.lifespan_dependencies.append(sub_dependant)
else:
raise FastAPIError(
f"Dependency \"{param_name}\" of `{call}` has an invalid "
f"sub-dependency scope: \"{param_details.depends.dependency_scope}\""
) )
dependant.dependencies.append(sub_dependant)
continue continue
if add_non_field_param_to_dependency( if add_non_field_param_to_dependency(
param_name=param_name, param_name=param_name,
@ -306,8 +399,12 @@ def get_dependant(
return dependant return dependant
# Kept for backwards compatibility
get_dependant = get_endpoint_dependant
def add_non_field_param_to_dependency( def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant *, param_name: str, type_annotation: Any, dependant: EndpointDependant
) -> Optional[bool]: ) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request): if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name dependant.request_param_name = param_name
@ -501,7 +598,7 @@ def analyze_param(
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: def add_param_to_fields(*, field: ModelField, dependant: EndpointDependant) -> None:
field_info = field.field_info field_info = field.field_info
field_info_in = getattr(field_info, "in_", None) field_info_in = getattr(field_info, "in_", None)
if field_info_in == params.ParamTypes.path: if field_info_in == params.ParamTypes.path:
@ -550,6 +647,82 @@ async def solve_generator(
return await stack.enter_async_context(cm) return await stack.enter_async_context(cm)
@dataclass
class SolvedLifespanDependant:
value: Any
dependency_cache: Dict[Callable[..., Any], Any]
async def solve_lifespan_dependant(
*,
dependant: LifespanDependant,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[LifespanDependantCacheKey, Callable[..., Any]]] = None,
async_exit_stack: AsyncExitStack,
) -> SolvedLifespanDependant:
dependency_cache = dependency_cache or {}
if dependant.use_cache and dependant.cache_key in dependency_cache:
return SolvedLifespanDependant(
value=dependency_cache[dependant.cache_key],
dependency_cache=dependency_cache,
)
dependency_arguments: Dict[str, Any] = {}
sub_dependant: LifespanDependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Callable[..., Any], sub_dependant.cache_key
)
assert sub_dependant.name, (
"Lifespan scoped dependencies should not be able to have "
"subdependencies with no name"
)
sub_dependant_to_solve = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
sub_dependant_to_solve = get_lifespan_dependant(
call=call,
name=sub_dependant.name,
caller=dependant.call
)
solved_sub_dependant = await solve_lifespan_dependant(
dependant=sub_dependant_to_solve,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
dependency_cache.update(solved_sub_dependant.dependency_cache)
dependency_arguments[sub_dependant.name] = solved_sub_dependant.value
if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call):
value = await solve_generator(
call=dependant.call,
stack=async_exit_stack,
sub_values=dependency_arguments
)
elif is_coroutine_callable(dependant.call):
value = await dependant.call(**dependency_arguments)
else:
value = await run_in_threadpool(dependant.call, **dependency_arguments)
if dependant.cache_key not in dependency_cache:
dependency_cache[dependant.cache_key] = value
return SolvedLifespanDependant(
value=value,
dependency_cache=dependency_cache,
)
@dataclass @dataclass
class SolvedDependency: class SolvedDependency:
values: Dict[str, Any] values: Dict[str, Any]
@ -562,7 +735,7 @@ class SolvedDependency:
async def solve_dependencies( async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
dependant: Dependant, dependant: EndpointDependant,
body: Optional[Union[Dict[str, Any], FormData]] = None, body: Optional[Union[Dict[str, Any], FormData]] = None,
background_tasks: Optional[StarletteBackgroundTasks] = None, background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
@ -573,13 +746,35 @@ async def solve_dependencies(
) -> SolvedDependency: ) -> SolvedDependency:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[Any] = []
for sub_dependant in dependant.lifespan_dependencies:
if sub_dependant.name is None:
continue
try:
lifespan_scoped_dependencies = request.state.__fastapi__[
"lifespan_scoped_dependencies"]
except AttributeError as e:
raise FastAPIError(
"FastAPI's internal lifespan was not initialized"
) from e
try:
value = lifespan_scoped_dependencies[sub_dependant.cache_key]
except KeyError as e:
raise FastAPIError(
f"Dependency {sub_dependant.name} of {dependant.call} "
f"was not initialized."
) from e
values[sub_dependant.name] = value
if response is None: if response is None:
response = Response() response = Response()
del response.headers["content-length"] del response.headers["content-length"]
response.status_code = None # type: ignore response.status_code = None # type: ignore
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
sub_dependant: Dependant for sub_dependant in dependant.endpoint_dependencies:
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast( sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
@ -595,7 +790,7 @@ async def solve_dependencies(
dependency_overrides_provider, "dependency_overrides", {} dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call) ).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant( use_sub_dependant = get_endpoint_dependant(
path=use_path, path=use_path,
call=call, call=call,
name=sub_dependant.name, name=sub_dependant.name,
@ -910,7 +1105,7 @@ async def request_body_to_args(
def get_body_field( def get_body_field(
*, flat_dependant: Dependant, name: str, embed_body_fields: bool *, flat_dependant: EndpointDependant, name: str, embed_body_fields: bool
) -> Optional[ModelField]: ) -> Optional[ModelField]:
""" """
Get a ModelField representing the request body for a path operation, combining Get a ModelField representing the request body for a path operation, combining

46
fastapi/lifespan.py

@ -0,0 +1,46 @@
from __future__ import annotations
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Callable, Dict, List
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey
from fastapi.dependencies.utils import solve_lifespan_dependant
from fastapi.routing import APIRoute
if TYPE_CHECKING:
from fastapi import FastAPI
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]:
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {}
for route in app.router.routes:
if not isinstance(route, APIRoute):
continue
for sub_dependant in route.lifespan_dependencies:
if sub_dependant.cache_key in lifespan_dependants_cache:
continue
lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant
return list(lifespan_dependants_cache.values())
async def resolve_lifespan_dependants(
*,
app: FastAPI,
async_exit_stack: AsyncExitStack
) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]:
lifespan_dependants = _get_lifespan_dependants(app)
dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {}
for lifespan_dependant in lifespan_dependants:
solved_dependency = await solve_lifespan_dependant(
dependant=lifespan_dependant,
dependency_overrides_provider=app,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack
)
dependency_cache.update(solved_dependency.dependency_cache)
return dependency_cache

6
fastapi/openapi/utils.py

@ -15,7 +15,7 @@ from fastapi._compat import (
lenient_issubclass, lenient_issubclass,
) )
from fastapi.datastructures import DefaultPlaceholder from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import EndpointDependant
from fastapi.dependencies.utils import ( from fastapi.dependencies.utils import (
_get_flat_fields_from_params, _get_flat_fields_from_params,
get_flat_dependant, get_flat_dependant,
@ -75,7 +75,7 @@ status_code_ranges: Dict[str, str] = {
def get_openapi_security_definitions( def get_openapi_security_definitions(
flat_dependant: Dependant, flat_dependant: EndpointDependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {} security_definitions = {}
operation_security = [] operation_security = []
@ -93,7 +93,7 @@ def get_openapi_security_definitions(
def _get_openapi_operation_parameters( def _get_openapi_operation_parameters(
*, *,
dependant: Dependant, dependant: EndpointDependant,
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
field_mapping: Dict[ field_mapping: Dict[

32
fastapi/param_functions.py

@ -1,8 +1,11 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Sequence, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi import params from fastapi import params
from fastapi._compat import Undefined from fastapi._compat import Undefined
from fastapi.openapi.models import Example from fastapi.openapi.models import Example
from fastapi.params import DependencyScope
from typing_extensions import Annotated, Doc, deprecated from typing_extensions import Annotated, Doc, deprecated
_Unset: Any = Undefined _Unset: Any = Undefined
@ -2244,6 +2247,33 @@ def Depends( # noqa: N802
""" """
), ),
] = True, ] = True,
dependency_scope: Annotated[
DependencyScope,
Doc(
"""
The scope in which the dependency value should be evaluated. Can be
either `"endpoint"` or `"lifespan"`.
If `dependency_scope` is set to "endpoint" (the default), the
dependency will be setup and teardown for every request.
If `dependency_scope` is set to `"lifespan"` the dependency would
be setup at the start of the entire application's lifespan. The
evaluated dependency would be then reused across all endpoints.
The dependency would be teared down as a part of the application's
shutdown process.
Note that dependencies defined with the `"endpoint"` scope may use
sub-dependencies defined with the `"lifespan"` scope, but not the
other way around;
Dependencies defined with the `"lifespan"` scope may not use
sub-dependencies with `"endpoint"` scope, nor can they use
other "endpoint scoped" arguments such as "Path", "Body", "Query",
or any other annotation which does not make sense in a scope of an
application's entire lifespan.
"""
)
] = "endpoint"
) -> Any: ) -> Any:
""" """
Declare a FastAPI dependency. Declare a FastAPI dependency.
@ -2274,7 +2304,7 @@ def Depends( # noqa: N802
return commons return commons
``` ```
""" """
return params.Depends(dependency=dependency, use_cache=use_cache) return params.Depends(dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope)
def Security( # noqa: N802 def Security( # noqa: N802

23
fastapi/params.py

@ -4,11 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi.openapi.models import Example from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated from typing_extensions import Annotated, Literal, TypeAlias, deprecated
from ._compat import PYDANTIC_V2, PYDANTIC_VERSION, Undefined from ._compat import PYDANTIC_V2, PYDANTIC_VERSION, Undefined
_Unset: Any = Undefined _Unset: Any = Undefined
DependencyScope: TypeAlias = Literal["endpoint", "lifespan"]
class ParamTypes(Enum): class ParamTypes(Enum):
@ -759,15 +760,25 @@ class File(Form):
class Depends: class Depends:
def __init__( def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
dependency_scope: DependencyScope = "endpoint"
): ):
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache self.use_cache = use_cache
self.dependency_scope = dependency_scope
def __repr__(self) -> str: def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False" cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})" if self.dependency_scope == "endpoint":
dependency_scope = ""
else:
dependency_scope = f", dependency_scope=\"{self.dependency_scope}\""
return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})"
class Security(Depends): class Security(Depends):
@ -778,5 +789,9 @@ class Security(Depends):
scopes: Optional[Sequence[str]] = None, scopes: Optional[Sequence[str]] = None,
use_cache: bool = True, use_cache: bool = True,
): ):
super().__init__(dependency=dependency, use_cache=use_cache) super().__init__(
dependency=dependency,
use_cache=use_cache,
dependency_scope="endpoint"
)
self.scopes = scopes or [] self.scopes = scopes or []

55
fastapi/routing.py

@ -31,11 +31,11 @@ from fastapi._compat import (
lenient_issubclass, lenient_issubclass,
) )
from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import EndpointDependant, LifespanDependant
from fastapi.dependencies.utils import ( from fastapi.dependencies.utils import (
_should_embed_body_fields, _should_embed_body_fields,
get_body_field, get_body_field,
get_dependant, get_endpoint_dependant,
get_flat_dependant, get_flat_dependant,
get_parameterless_sub_dependant, get_parameterless_sub_dependant,
get_typed_return_annotation, get_typed_return_annotation,
@ -73,7 +73,7 @@ from starlette.routing import (
from starlette.routing import Mount as Mount # noqa from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.types import AppType, ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, deprecated from typing_extensions import Annotated, Doc, assert_never, deprecated
def _prepare_response_content( def _prepare_response_content(
@ -123,7 +123,7 @@ def _prepare_response_content(
return res return res
def _merge_lifespan_context( def merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any] original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]: ) -> Lifespan[Any]:
@asynccontextmanager @asynccontextmanager
@ -202,7 +202,7 @@ async def serialize_response(
async def run_endpoint_function( async def run_endpoint_function(
*, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool *, dependant: EndpointDependant, values: Dict[str, Any], is_coroutine: bool
) -> Any: ) -> Any:
# Only called by get_request_handler. Has been split into its own function to # Only called by get_request_handler. Has been split into its own function to
# facilitate profiling endpoints, since inner functions are harder to profile. # facilitate profiling endpoints, since inner functions are harder to profile.
@ -215,7 +215,7 @@ async def run_endpoint_function(
def get_request_handler( def get_request_handler(
dependant: Dependant, dependant: EndpointDependant,
body_field: Optional[ModelField] = None, body_field: Optional[ModelField] = None,
status_code: Optional[int] = None, status_code: Optional[int] = None,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse), response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
@ -358,7 +358,7 @@ def get_request_handler(
def get_websocket_app( def get_websocket_app(
dependant: Dependant, dependant: EndpointDependant,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False, embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
@ -400,12 +400,20 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( sub_dependant = get_parameterless_sub_dependant(
0, depends=depends,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), path=self.path_format,
caller=self
) )
if depends.dependency_scope == "endpoint":
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params self._flat_dependant.body_params
@ -424,6 +432,10 @@ class APIWebSocketRoute(routing.WebSocketRoute):
child_scope["route"] = self child_scope["route"] = self
return match, child_scope return match, child_scope
@property
def lifespan_dependencies(self) -> List[LifespanDependant]:
return self._flat_dependant.lifespan_dependencies
class APIRoute(routing.Route): class APIRoute(routing.Route):
def __init__( def __init__(
@ -549,12 +561,19 @@ class APIRoute(routing.Route):
self.response_fields = {} self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable" assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( sub_dependant = get_parameterless_sub_dependant(
0, depends=depends,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), path=self.path_format,
caller=self.__call__
) )
if depends.dependency_scope == "endpoint":
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params self._flat_dependant.body_params
@ -589,6 +608,10 @@ class APIRoute(routing.Route):
child_scope["route"] = self child_scope["route"] = self
return match, child_scope return match, child_scope
@property
def lifespan_dependencies(self) -> List[LifespanDependant]:
return self._flat_dependant.lifespan_dependencies
class APIRouter(routing.Router): class APIRouter(routing.Router):
""" """
@ -1356,7 +1379,7 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler) self.add_event_handler("startup", handler)
for handler in router.on_shutdown: for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler) self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context( self.lifespan_context = merge_lifespan_context(
self.lifespan_context, self.lifespan_context,
router.lifespan_context, router.lifespan_context,
) )

703
tests/test_lifespan_scoped_dependencies.py

@ -0,0 +1,703 @@
from enum import StrEnum, auto
from typing import Any, AsyncGenerator, List, Tuple, TypeVar
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
)
from fastapi.exceptions import FastAPIError
from fastapi.params import Security
from fastapi.security import SecurityScopes
from starlette.testclient import TestClient
from typing_extensions import Annotated, Generator, Literal, assert_never
T = TypeVar('T')
class DependencyStyle(StrEnum):
SYNC_FUNCTION = auto()
ASYNC_FUNCTION = auto()
SYNC_GENERATOR = auto()
ASYNC_GENERATOR = auto()
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle, *,
should_error: bool = False
):
self.activation_times = 0
self.deactivation_times = 0
self.dependency_style = dependency_style
self._should_error = should_error
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
return self._synchronous_function_dependency
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
return self._synchronous_generator_dependency
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
return self._asynchronous_function_dependency
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style)
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _synchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
response = client.post(url)
response.raise_for_status()
assert response.json() == expected_response
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
):
assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)]
) -> None:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
if routing_style == "app":
app = FastAPI(dependencies=[depends])
@app.post("/test")
async def endpoint() -> None:
return None
else:
app = FastAPI()
router = APIRouter(dependencies=[depends])
@router.post("/test")
async def endpoint() -> None:
return None
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"]
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
@router.post("/test")
async def endpoint(
dependency: Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
) -> List[int]:
return dependency
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1]),
("/test", [1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2]),
("/test", [1, 2]),
],
dependency_factory=dependency_factory,
expected_activation_times=2
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
@router.post("/test1")
async def endpoint(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test1", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test1", [1, 2, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=3
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
@router.post("/test1")
async def endpoint(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
@router.post("/test2")
async def endpoint2(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=5
)
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends],
) -> int:
return dependency
if routing_style == "router":
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation
):
async def dependency_func(param: annotation) -> None:
yield
app = FastAPI()
with pytest.raises(FastAPIError):
@app.post("/test")
async def endpoint(
dependency: Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")]
) -> None:
return
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle
):
dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
@app.post("/test")
async def endpoint(
dependency: Annotated[int, Depends(
lifespan_scoped_dependency,
dependency_scope="lifespan"
)]
) -> int:
return dependency
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2
)
@pytest.mark.parametrize("depends_class", [Depends, Security])
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[
"websocket", "endpoint"
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
route_type
):
async def sub_dependency() -> None:
pass
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
yield
app = FastAPI()
route_decorator = route_type(app, "/test")
with pytest.raises(FastAPIError):
@route_decorator
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")]
) -> None:
return
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
with pytest.raises(FastAPIError):
@router.post("/test")
async def endpoint(
dependency: Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
)]
) -> None:
assert dependency == 1
return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
with pytest.raises(FastAPIError):
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"]
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {}
try:
with pytest.raises(FastAPIError):
client.post("/test")
finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
internal_state = client.app_state["__fastapi__"]
del client.app_state["__fastapi__"]
try:
with pytest.raises(FastAPIError):
client.post("/test")
finally:
client.app_state["__fastapi__"] = internal_state
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style):
dependency_factory= DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with pytest.raises(ValueError) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)
# TODO: Add tests for dependency_overrides
# TODO: Add a websocket equivalent to all tests

21
tests/test_params_repr.py

@ -1,5 +1,6 @@
from typing import Any, List from typing import Any, List
import pytest
from dirty_equals import IsOneOf from dirty_equals import IsOneOf
from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query
@ -143,10 +144,16 @@ def test_body_repr_list():
assert repr(Body([])) == "Body([])" assert repr(Body([])) == "Body([])"
def test_depends_repr(): @pytest.mark.parametrize(["depends", "expected_repr"], [
assert repr(Depends()) == "Depends(NoneType)" [Depends(), "Depends(NoneType)"],
assert repr(Depends(get_user)) == "Depends(get_user)" [Depends(get_user), "Depends(get_user)"],
assert repr(Depends(use_cache=False)) == "Depends(NoneType, use_cache=False)" [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"],
assert ( [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"],
repr(Depends(get_user, use_cache=False)) == "Depends(get_user, use_cache=False)"
) [Depends(dependency_scope="lifespan"), "Depends(NoneType, dependency_scope=\"lifespan\")"],
[Depends(get_user, dependency_scope="lifespan"), "Depends(get_user, dependency_scope=\"lifespan\")"],
[Depends(use_cache=False, dependency_scope="lifespan"), "Depends(NoneType, use_cache=False, dependency_scope=\"lifespan\")"],
[Depends(get_user, use_cache=False, dependency_scope="lifespan"), "Depends(get_user, use_cache=False, dependency_scope=\"lifespan\")"],
])
def test_depends_repr(depends, expected_repr):
assert repr(depends) == expected_repr

16
tests/test_router_events.py

@ -199,6 +199,9 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None:
"app_specific": True, "app_specific": True,
"router_specific": True, "router_specific": True,
"overridden": "app", "overridden": "app",
"__fastapi__": {
"lifespan_scoped_dependencies": {}
},
} }
@ -216,7 +219,11 @@ def test_merged_no_return_lifespans_return_none() -> None:
app.include_router(router) app.include_router(router)
with TestClient(app) as client: with TestClient(app) as client:
assert not client.app_state assert client.app_state == {
"__fastapi__": {
"lifespan_scoped_dependencies": {}
}
}
def test_merged_mixed_state_lifespans() -> None: def test_merged_mixed_state_lifespans() -> None:
@ -239,4 +246,9 @@ def test_merged_mixed_state_lifespans() -> None:
app.include_router(router) app.include_router(router)
with TestClient(app) as client: with TestClient(app) as client:
assert client.app_state == {"router": True} assert client.app_state == {
"router": True,
"__fastapi__": {
"lifespan_scoped_dependencies": {}
}
}

Loading…
Cancel
Save