Browse Source

Added support for lifespan-scoped dependencies using a new dependency_scope argument.

pull/12529/head
Nir Schulman 9 months ago
parent
commit
25407d039a
  1. 56
      fastapi/applications.py
  2. 46
      fastapi/dependencies/models.py
  3. 299
      fastapi/dependencies/utils.py
  4. 46
      fastapi/lifespan.py
  5. 6
      fastapi/openapi/utils.py
  6. 32
      fastapi/param_functions.py
  7. 23
      fastapi/params.py
  8. 55
      fastapi/routing.py
  9. 703
      tests/test_lifespan_scoped_dependencies.py
  10. 21
      tests/test_params_repr.py
  11. 16
      tests/test_router_events.py

56
fastapi/applications.py

@ -1,6 +1,8 @@
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
@ -15,12 +17,14 @@ from typing import (
from fastapi import routing
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.utils import is_coroutine_callable
from fastapi.exception_handlers import (
http_exception_handler,
request_validation_exception_handler,
websocket_request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.lifespan import resolve_lifespan_dependants
from fastapi.logger import logger
from fastapi.openapi.docs import (
get_redoc_html,
@ -29,9 +33,11 @@ from fastapi.openapi.docs import (
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.routing import merge_lifespan_context
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import generate_unique_id
from starlette.applications import Starlette
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import State
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
@ -929,12 +935,24 @@ class FastAPI(Starlette):
"""
),
] = {}
if lifespan is None:
lifespan = FastAPI._internal_lifespan
else:
lifespan = merge_lifespan_context(
FastAPI._internal_lifespan,
lifespan
)
# Since we always use a lifespan, starlette will no longer run event
# handlers which are defined in the scope of the application.
# We therefore need to call them ourselves.
self._on_startup = on_startup or []
self._on_shutdown = on_shutdown or []
self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
redirect_slashes=redirect_slashes,
dependency_overrides_provider=self,
on_startup=on_startup,
on_shutdown=on_shutdown,
lifespan=lifespan,
default_response_class=default_response_class,
dependencies=dependencies,
@ -963,6 +981,32 @@ class FastAPI(Starlette):
self.middleware_stack: Union[ASGIApp, None] = None
self.setup()
@asynccontextmanager
async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]:
async with AsyncExitStack() as exit_stack:
lifespan_scoped_dependencies = await resolve_lifespan_dependants(
app=self,
async_exit_stack=exit_stack
)
try:
for handler in self._on_startup:
if is_coroutine_callable(handler):
await handler()
else:
await run_in_threadpool(handler)
yield {
"__fastapi__": {
"lifespan_scoped_dependencies": lifespan_scoped_dependencies
}
}
finally:
for handler in self._on_shutdown:
if is_coroutine_callable(handler):
await handler()
else:
await run_in_threadpool(handler)
def openapi(self) -> Dict[str, Any]:
"""
Generate the OpenAPI schema of the application. This is called by FastAPI
@ -4492,7 +4536,13 @@ class FastAPI(Starlette):
Read more about it in the
[FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated).
"""
return self.router.on_event(event_type)
def decorator(func: DecoratedCallable) -> DecoratedCallable:
if event_type == "startup":
self._on_startup.append(func)
else:
self._on_shutdown.append(func)
return func
return decorator
def middleware(
self,

46
fastapi/dependencies/models.py

@ -1,8 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
from typing_extensions import TypeAlias
@dataclass
@ -11,17 +12,41 @@ class SecurityRequirement:
scopes: Optional[Sequence[str]] = None
LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]]
@dataclass
class LifespanDependant:
caller: Callable[..., Any]
dependencies: List["LifespanDependant"] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
cache_key: LifespanDependantCacheKey = field(init=False)
def __post_init__(self) -> None:
if self.use_cache:
self.cache_key = self.call
else:
self.cache_key = (self.caller, self.name)
EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
@dataclass
class Dependant:
class EndpointDependant:
endpoint_dependencies: List["EndpointDependant"] = field(default_factory=list)
lifespan_dependencies: List[LifespanDependant] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(
init=False)
path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list)
cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None
websocket_param_name: Optional[str] = None
http_connection_param_name: Optional[str] = None
@ -29,9 +54,16 @@ class Dependant:
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
# Kept for backwards compatibility
@property
def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]:
return tuple(self.endpoint_dependencies + self.lifespan_dependencies)
# Kept for backwards compatibility
Dependant = EndpointDependant
CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey]

299
fastapi/dependencies/utils.py

@ -51,7 +51,14 @@ from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.models import (
CacheKey,
EndpointDependant,
LifespanDependant,
LifespanDependantCacheKey,
SecurityRequirement,
)
from fastapi.exceptions import FastAPIError
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -112,8 +119,9 @@ def get_param_sub_dependant(
param_name: str,
depends: params.Depends,
path: str,
caller: Callable[..., Any],
security_scopes: Optional[List[str]] = None,
) -> Dependant:
) -> Union[EndpointDependant, LifespanDependant]:
assert depends.dependency
return get_sub_dependant(
depends=depends,
@ -121,14 +129,25 @@ def get_param_sub_dependant(
path=path,
name=param_name,
security_scopes=security_scopes,
caller=caller,
)
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
def get_parameterless_sub_dependant(
*,
depends: params.Depends,
path: str,
caller: Callable[..., Any]
) -> Union[EndpointDependant, LifespanDependant]:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
return get_sub_dependant(
depends=depends,
dependency=depends.dependency,
path=path,
caller=caller
)
def get_sub_dependant(
@ -136,57 +155,69 @@ def get_sub_dependant(
depends: params.Depends,
dependency: Callable[..., Any],
path: str,
caller: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
) -> Dependant:
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
if isinstance(dependency, SecurityBase):
use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
) -> Union[EndpointDependant, LifespanDependant]:
if depends.dependency_scope == "lifespan":
return get_lifespan_dependant(
caller=caller,
call=depends.dependency,
name=name,
use_cache=depends.use_cache
)
elif depends.dependency_scope == "endpoint":
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
if isinstance(dependency, SecurityBase):
use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_endpoint_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
else:
raise ValueError(
f"Dependency {name} of {caller} has an invalid "
f"sub-dependency scope: {depends.dependency_scope}"
)
sub_dependant = get_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant(
dependant: Dependant,
dependant: EndpointDependant,
*,
skip_repeats: bool = False,
visited: Optional[List[CacheKey]] = None,
) -> Dependant:
) -> EndpointDependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
flat_dependant = Dependant(
flat_dependant = EndpointDependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
lifespan_dependencies=dependant.lifespan_dependencies.copy(),
use_cache=dependant.use_cache,
path=dependant.path,
path=dependant.path
)
for sub_dependant in dependant.dependencies:
for sub_dependant in dependant.endpoint_dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
@ -198,6 +229,7 @@ def get_flat_dependant(
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
flat_dependant.lifespan_dependencies.extend(flat_sub.lifespan_dependencies)
return flat_dependant
@ -211,7 +243,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
return fields
def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_flat_params(dependant: EndpointDependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
@ -254,18 +286,64 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
def get_dependant(
def get_lifespan_dependant(
*,
caller: Callable[..., Any],
call: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True,
) -> LifespanDependant:
dependency_signature = get_typed_signature(call)
signature_params = dependency_signature.parameters
dependant = LifespanDependant(
call=call,
name=name,
use_cache=use_cache,
caller=caller
)
for param_name, param in signature_params.items():
param_details = analyze_param(
param_name=param_name,
annotation=param.annotation,
value=param.default,
is_path_param=False,
)
if param_details.depends is None:
raise FastAPIError(
f"Lifespan dependency {dependant.name} was defined with an "
f"invalid argument {param_name}. Lifespan dependencies may "
f"only use other lifespan dependencies as arguments.")
if param_details.depends.dependency_scope != "lifespan":
raise FastAPIError(
"Lifespan dependency may not use "
"sub-dependencies of other scopes."
)
sub_dependant = get_lifespan_dependant(
name=param_name,
call=param_details.depends.dependency,
use_cache=param_details.depends.use_cache,
caller=call
)
dependant.dependencies.append(sub_dependant)
return dependant
def get_endpoint_dependant(
*,
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
) -> Dependant:
) -> EndpointDependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(
dependant = EndpointDependant(
call=call,
name=name,
path=path,
@ -281,13 +359,28 @@ def get_dependant(
is_path_param=is_path_param,
)
if param_details.depends is not None:
sub_dependant = get_param_sub_dependant(
param_name=param_name,
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
)
dependant.dependencies.append(sub_dependant)
if param_details.depends.dependency_scope == "endpoint":
sub_dependant = get_param_sub_dependant(
param_name=param_name,
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
caller=call,
)
dependant.endpoint_dependencies.append(sub_dependant)
elif param_details.depends.dependency_scope == "lifespan":
sub_dependant = get_lifespan_dependant(
caller=call,
call=param_details.depends.dependency,
name=param_name,
use_cache=param_details.depends.use_cache,
)
dependant.lifespan_dependencies.append(sub_dependant)
else:
raise FastAPIError(
f"Dependency \"{param_name}\" of `{call}` has an invalid "
f"sub-dependency scope: \"{param_details.depends.dependency_scope}\""
)
continue
if add_non_field_param_to_dependency(
param_name=param_name,
@ -306,8 +399,12 @@ def get_dependant(
return dependant
# Kept for backwards compatibility
get_dependant = get_endpoint_dependant
def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant
*, param_name: str, type_annotation: Any, dependant: EndpointDependant
) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name
@ -501,7 +598,7 @@ def analyze_param(
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
def add_param_to_fields(*, field: ModelField, dependant: EndpointDependant) -> None:
field_info = field.field_info
field_info_in = getattr(field_info, "in_", None)
if field_info_in == params.ParamTypes.path:
@ -550,6 +647,82 @@ async def solve_generator(
return await stack.enter_async_context(cm)
@dataclass
class SolvedLifespanDependant:
value: Any
dependency_cache: Dict[Callable[..., Any], Any]
async def solve_lifespan_dependant(
*,
dependant: LifespanDependant,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[LifespanDependantCacheKey, Callable[..., Any]]] = None,
async_exit_stack: AsyncExitStack,
) -> SolvedLifespanDependant:
dependency_cache = dependency_cache or {}
if dependant.use_cache and dependant.cache_key in dependency_cache:
return SolvedLifespanDependant(
value=dependency_cache[dependant.cache_key],
dependency_cache=dependency_cache,
)
dependency_arguments: Dict[str, Any] = {}
sub_dependant: LifespanDependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Callable[..., Any], sub_dependant.cache_key
)
assert sub_dependant.name, (
"Lifespan scoped dependencies should not be able to have "
"subdependencies with no name"
)
sub_dependant_to_solve = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
sub_dependant_to_solve = get_lifespan_dependant(
call=call,
name=sub_dependant.name,
caller=dependant.call
)
solved_sub_dependant = await solve_lifespan_dependant(
dependant=sub_dependant_to_solve,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
dependency_cache.update(solved_sub_dependant.dependency_cache)
dependency_arguments[sub_dependant.name] = solved_sub_dependant.value
if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call):
value = await solve_generator(
call=dependant.call,
stack=async_exit_stack,
sub_values=dependency_arguments
)
elif is_coroutine_callable(dependant.call):
value = await dependant.call(**dependency_arguments)
else:
value = await run_in_threadpool(dependant.call, **dependency_arguments)
if dependant.cache_key not in dependency_cache:
dependency_cache[dependant.cache_key] = value
return SolvedLifespanDependant(
value=value,
dependency_cache=dependency_cache,
)
@dataclass
class SolvedDependency:
values: Dict[str, Any]
@ -562,7 +735,7 @@ class SolvedDependency:
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
dependant: EndpointDependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
@ -573,13 +746,35 @@ async def solve_dependencies(
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
for sub_dependant in dependant.lifespan_dependencies:
if sub_dependant.name is None:
continue
try:
lifespan_scoped_dependencies = request.state.__fastapi__[
"lifespan_scoped_dependencies"]
except AttributeError as e:
raise FastAPIError(
"FastAPI's internal lifespan was not initialized"
) from e
try:
value = lifespan_scoped_dependencies[sub_dependant.cache_key]
except KeyError as e:
raise FastAPIError(
f"Dependency {sub_dependant.name} of {dependant.call} "
f"was not initialized."
) from e
values[sub_dependant.name] = value
if response is None:
response = Response()
del response.headers["content-length"]
response.status_code = None # type: ignore
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
for sub_dependant in dependant.endpoint_dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
@ -595,7 +790,7 @@ async def solve_dependencies(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
use_sub_dependant = get_endpoint_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
@ -910,7 +1105,7 @@ async def request_body_to_args(
def get_body_field(
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
*, flat_dependant: EndpointDependant, name: str, embed_body_fields: bool
) -> Optional[ModelField]:
"""
Get a ModelField representing the request body for a path operation, combining

46
fastapi/lifespan.py

@ -0,0 +1,46 @@
from __future__ import annotations
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Callable, Dict, List
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey
from fastapi.dependencies.utils import solve_lifespan_dependant
from fastapi.routing import APIRoute
if TYPE_CHECKING:
from fastapi import FastAPI
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]:
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {}
for route in app.router.routes:
if not isinstance(route, APIRoute):
continue
for sub_dependant in route.lifespan_dependencies:
if sub_dependant.cache_key in lifespan_dependants_cache:
continue
lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant
return list(lifespan_dependants_cache.values())
async def resolve_lifespan_dependants(
*,
app: FastAPI,
async_exit_stack: AsyncExitStack
) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]:
lifespan_dependants = _get_lifespan_dependants(app)
dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {}
for lifespan_dependant in lifespan_dependants:
solved_dependency = await solve_lifespan_dependant(
dependant=lifespan_dependant,
dependency_overrides_provider=app,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack
)
dependency_cache.update(solved_dependency.dependency_cache)
return dependency_cache

6
fastapi/openapi/utils.py

@ -15,7 +15,7 @@ from fastapi._compat import (
lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.models import EndpointDependant
from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
@ -75,7 +75,7 @@ status_code_ranges: Dict[str, str] = {
def get_openapi_security_definitions(
flat_dependant: Dependant,
flat_dependant: EndpointDependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
@ -93,7 +93,7 @@ def get_openapi_security_definitions(
def _get_openapi_operation_parameters(
*,
dependant: Dependant,
dependant: EndpointDependant,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[

32
fastapi/param_functions.py

@ -1,8 +1,11 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
from fastapi.params import DependencyScope
from typing_extensions import Annotated, Doc, deprecated
_Unset: Any = Undefined
@ -2244,6 +2247,33 @@ def Depends( # noqa: N802
"""
),
] = True,
dependency_scope: Annotated[
DependencyScope,
Doc(
"""
The scope in which the dependency value should be evaluated. Can be
either `"endpoint"` or `"lifespan"`.
If `dependency_scope` is set to "endpoint" (the default), the
dependency will be setup and teardown for every request.
If `dependency_scope` is set to `"lifespan"` the dependency would
be setup at the start of the entire application's lifespan. The
evaluated dependency would be then reused across all endpoints.
The dependency would be teared down as a part of the application's
shutdown process.
Note that dependencies defined with the `"endpoint"` scope may use
sub-dependencies defined with the `"lifespan"` scope, but not the
other way around;
Dependencies defined with the `"lifespan"` scope may not use
sub-dependencies with `"endpoint"` scope, nor can they use
other "endpoint scoped" arguments such as "Path", "Body", "Query",
or any other annotation which does not make sense in a scope of an
application's entire lifespan.
"""
)
] = "endpoint"
) -> Any:
"""
Declare a FastAPI dependency.
@ -2274,7 +2304,7 @@ def Depends( # noqa: N802
return commons
```
"""
return params.Depends(dependency=dependency, use_cache=use_cache)
return params.Depends(dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope)
def Security( # noqa: N802

23
fastapi/params.py

@ -4,11 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated
from typing_extensions import Annotated, Literal, TypeAlias, deprecated
from ._compat import PYDANTIC_V2, PYDANTIC_VERSION, Undefined
_Unset: Any = Undefined
DependencyScope: TypeAlias = Literal["endpoint", "lifespan"]
class ParamTypes(Enum):
@ -759,15 +760,25 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
dependency_scope: DependencyScope = "endpoint"
):
self.dependency = dependency
self.use_cache = use_cache
self.dependency_scope = dependency_scope
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
if self.dependency_scope == "endpoint":
dependency_scope = ""
else:
dependency_scope = f", dependency_scope=\"{self.dependency_scope}\""
return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})"
class Security(Depends):
@ -778,5 +789,9 @@ class Security(Depends):
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
super().__init__(
dependency=dependency,
use_cache=use_cache,
dependency_scope="endpoint"
)
self.scopes = scopes or []

55
fastapi/routing.py

@ -31,11 +31,11 @@ from fastapi._compat import (
lenient_issubclass,
)
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.models import EndpointDependant, LifespanDependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_endpoint_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
@ -73,7 +73,7 @@ from starlette.routing import (
from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, deprecated
from typing_extensions import Annotated, Doc, assert_never, deprecated
def _prepare_response_content(
@ -123,7 +123,7 @@ def _prepare_response_content(
return res
def _merge_lifespan_context(
def merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@asynccontextmanager
@ -202,7 +202,7 @@ async def serialize_response(
async def run_endpoint_function(
*, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
*, dependant: EndpointDependant, values: Dict[str, Any], is_coroutine: bool
) -> Any:
# Only called by get_request_handler. Has been split into its own function to
# facilitate profiling endpoints, since inner functions are harder to profile.
@ -215,7 +215,7 @@ async def run_endpoint_function(
def get_request_handler(
dependant: Dependant,
dependant: EndpointDependant,
body_field: Optional[ModelField] = None,
status_code: Optional[int] = None,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
@ -358,7 +358,7 @@ def get_request_handler(
def get_websocket_app(
dependant: Dependant,
dependant: EndpointDependant,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
@ -400,12 +400,20 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
sub_dependant = get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
caller=self
)
if depends.dependency_scope == "endpoint":
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
@ -424,6 +432,10 @@ class APIWebSocketRoute(routing.WebSocketRoute):
child_scope["route"] = self
return match, child_scope
@property
def lifespan_dependencies(self) -> List[LifespanDependant]:
return self._flat_dependant.lifespan_dependencies
class APIRoute(routing.Route):
def __init__(
@ -549,12 +561,19 @@ class APIRoute(routing.Route):
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
sub_dependant = get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
caller=self.__call__
)
if depends.dependency_scope == "endpoint":
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
@ -589,6 +608,10 @@ class APIRoute(routing.Route):
child_scope["route"] = self
return match, child_scope
@property
def lifespan_dependencies(self) -> List[LifespanDependant]:
return self._flat_dependant.lifespan_dependencies
class APIRouter(routing.Router):
"""
@ -1356,7 +1379,7 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context = merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)

703
tests/test_lifespan_scoped_dependencies.py

@ -0,0 +1,703 @@
from enum import StrEnum, auto
from typing import Any, AsyncGenerator, List, Tuple, TypeVar
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
)
from fastapi.exceptions import FastAPIError
from fastapi.params import Security
from fastapi.security import SecurityScopes
from starlette.testclient import TestClient
from typing_extensions import Annotated, Generator, Literal, assert_never
T = TypeVar('T')
class DependencyStyle(StrEnum):
SYNC_FUNCTION = auto()
ASYNC_FUNCTION = auto()
SYNC_GENERATOR = auto()
ASYNC_GENERATOR = auto()
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle, *,
should_error: bool = False
):
self.activation_times = 0
self.deactivation_times = 0
self.dependency_style = dependency_style
self._should_error = should_error
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
return self._synchronous_function_dependency
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
return self._synchronous_generator_dependency
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
return self._asynchronous_function_dependency
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style)
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _synchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
response = client.post(url)
response.raise_for_status()
assert response.json() == expected_response
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
):
assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)]
) -> None:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
if routing_style == "app":
app = FastAPI(dependencies=[depends])
@app.post("/test")
async def endpoint() -> None:
return None
else:
app = FastAPI()
router = APIRouter(dependencies=[depends])
@router.post("/test")
async def endpoint() -> None:
return None
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"]
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
@router.post("/test")
async def endpoint(
dependency: Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
) -> List[int]:
return dependency
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1]),
("/test", [1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2]),
("/test", [1, 2]),
],
dependency_factory=dependency_factory,
expected_activation_times=2
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
@router.post("/test1")
async def endpoint(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test1", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test1", [1, 2, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=3
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
@router.post("/test1")
async def endpoint(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
@router.post("/test2")
async def endpoint2(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=5
)
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends],
) -> int:
return dependency
if routing_style == "router":
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation
):
async def dependency_func(param: annotation) -> None:
yield
app = FastAPI()
with pytest.raises(FastAPIError):
@app.post("/test")
async def endpoint(
dependency: Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")]
) -> None:
return
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle
):
dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
@app.post("/test")
async def endpoint(
dependency: Annotated[int, Depends(
lifespan_scoped_dependency,
dependency_scope="lifespan"
)]
) -> int:
return dependency
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2
)
@pytest.mark.parametrize("depends_class", [Depends, Security])
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[
"websocket", "endpoint"
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
route_type
):
async def sub_dependency() -> None:
pass
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
yield
app = FastAPI()
route_decorator = route_type(app, "/test")
with pytest.raises(FastAPIError):
@route_decorator
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")]
) -> None:
return
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
with pytest.raises(FastAPIError):
@router.post("/test")
async def endpoint(
dependency: Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
)]
) -> None:
assert dependency == 1
return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
with pytest.raises(FastAPIError):
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"]
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {}
try:
with pytest.raises(FastAPIError):
client.post("/test")
finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
internal_state = client.app_state["__fastapi__"]
del client.app_state["__fastapi__"]
try:
with pytest.raises(FastAPIError):
client.post("/test")
finally:
client.app_state["__fastapi__"] = internal_state
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style):
dependency_factory= DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with pytest.raises(ValueError) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)
# TODO: Add tests for dependency_overrides
# TODO: Add a websocket equivalent to all tests

21
tests/test_params_repr.py

@ -1,5 +1,6 @@
from typing import Any, List
import pytest
from dirty_equals import IsOneOf
from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query
@ -143,10 +144,16 @@ def test_body_repr_list():
assert repr(Body([])) == "Body([])"
def test_depends_repr():
assert repr(Depends()) == "Depends(NoneType)"
assert repr(Depends(get_user)) == "Depends(get_user)"
assert repr(Depends(use_cache=False)) == "Depends(NoneType, use_cache=False)"
assert (
repr(Depends(get_user, use_cache=False)) == "Depends(get_user, use_cache=False)"
)
@pytest.mark.parametrize(["depends", "expected_repr"], [
[Depends(), "Depends(NoneType)"],
[Depends(get_user), "Depends(get_user)"],
[Depends(use_cache=False), "Depends(NoneType, use_cache=False)"],
[Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"],
[Depends(dependency_scope="lifespan"), "Depends(NoneType, dependency_scope=\"lifespan\")"],
[Depends(get_user, dependency_scope="lifespan"), "Depends(get_user, dependency_scope=\"lifespan\")"],
[Depends(use_cache=False, dependency_scope="lifespan"), "Depends(NoneType, use_cache=False, dependency_scope=\"lifespan\")"],
[Depends(get_user, use_cache=False, dependency_scope="lifespan"), "Depends(get_user, use_cache=False, dependency_scope=\"lifespan\")"],
])
def test_depends_repr(depends, expected_repr):
assert repr(depends) == expected_repr

16
tests/test_router_events.py

@ -199,6 +199,9 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None:
"app_specific": True,
"router_specific": True,
"overridden": "app",
"__fastapi__": {
"lifespan_scoped_dependencies": {}
},
}
@ -216,7 +219,11 @@ def test_merged_no_return_lifespans_return_none() -> None:
app.include_router(router)
with TestClient(app) as client:
assert not client.app_state
assert client.app_state == {
"__fastapi__": {
"lifespan_scoped_dependencies": {}
}
}
def test_merged_mixed_state_lifespans() -> None:
@ -239,4 +246,9 @@ def test_merged_mixed_state_lifespans() -> None:
app.include_router(router)
with TestClient(app) as client:
assert client.app_state == {"router": True}
assert client.app_state == {
"router": True,
"__fastapi__": {
"lifespan_scoped_dependencies": {}
}
}

Loading…
Cancel
Save