From 25407d039a9742daef3967a301d8f2aa23c442d5 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Thu, 24 Oct 2024 15:30:38 +0300 Subject: [PATCH 01/29] Added support for lifespan-scoped dependencies using a new dependency_scope argument. --- fastapi/applications.py | 56 +- fastapi/dependencies/models.py | 46 +- fastapi/dependencies/utils.py | 299 +++++++-- fastapi/lifespan.py | 46 ++ fastapi/openapi/utils.py | 6 +- fastapi/param_functions.py | 32 +- fastapi/params.py | 23 +- fastapi/routing.py | 55 +- tests/test_lifespan_scoped_dependencies.py | 703 +++++++++++++++++++++ tests/test_params_repr.py | 21 +- tests/test_router_events.py | 16 +- 11 files changed, 1208 insertions(+), 95 deletions(-) create mode 100644 fastapi/lifespan.py create mode 100644 tests/test_lifespan_scoped_dependencies.py diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..625690a74 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,6 +1,8 @@ +from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum from typing import ( Any, + AsyncGenerator, Awaitable, Callable, Coroutine, @@ -15,12 +17,14 @@ from typing import ( from fastapi import routing from fastapi.datastructures import Default, DefaultPlaceholder +from fastapi.dependencies.utils import is_coroutine_callable from fastapi.exception_handlers import ( http_exception_handler, request_validation_exception_handler, websocket_request_validation_exception_handler, ) from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError +from fastapi.lifespan import resolve_lifespan_dependants from fastapi.logger import logger from fastapi.openapi.docs import ( get_redoc_html, @@ -29,9 +33,11 @@ from fastapi.openapi.docs import ( ) from fastapi.openapi.utils import get_openapi from fastapi.params import Depends +from fastapi.routing import merge_lifespan_context from fastapi.types import DecoratedCallable, IncEx from fastapi.utils import generate_unique_id from starlette.applications import Starlette +from starlette.concurrency import run_in_threadpool from starlette.datastructures import State from starlette.exceptions import HTTPException from starlette.middleware import Middleware @@ -929,12 +935,24 @@ class FastAPI(Starlette): """ ), ] = {} + if lifespan is None: + lifespan = FastAPI._internal_lifespan + else: + lifespan = merge_lifespan_context( + FastAPI._internal_lifespan, + lifespan + ) + + # Since we always use a lifespan, starlette will no longer run event + # handlers which are defined in the scope of the application. + # We therefore need to call them ourselves. + self._on_startup = on_startup or [] + self._on_shutdown = on_shutdown or [] + self.router: routing.APIRouter = routing.APIRouter( routes=routes, redirect_slashes=redirect_slashes, dependency_overrides_provider=self, - on_startup=on_startup, - on_shutdown=on_shutdown, lifespan=lifespan, default_response_class=default_response_class, dependencies=dependencies, @@ -963,6 +981,32 @@ class FastAPI(Starlette): self.middleware_stack: Union[ASGIApp, None] = None self.setup() + @asynccontextmanager + async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]: + async with AsyncExitStack() as exit_stack: + lifespan_scoped_dependencies = await resolve_lifespan_dependants( + app=self, + async_exit_stack=exit_stack + ) + try: + for handler in self._on_startup: + if is_coroutine_callable(handler): + await handler() + else: + await run_in_threadpool(handler) + yield { + "__fastapi__": { + "lifespan_scoped_dependencies": lifespan_scoped_dependencies + } + } + finally: + for handler in self._on_shutdown: + if is_coroutine_callable(handler): + await handler() + else: + await run_in_threadpool(handler) + + def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. This is called by FastAPI @@ -4492,7 +4536,13 @@ class FastAPI(Starlette): Read more about it in the [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated). """ - return self.router.on_event(event_type) + def decorator(func: DecoratedCallable) -> DecoratedCallable: + if event_type == "startup": + self._on_startup.append(func) + else: + self._on_shutdown.append(func) + return func + return decorator def middleware( self, diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..471f9c402 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase +from typing_extensions import TypeAlias @dataclass @@ -11,17 +12,41 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None +LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]] + +@dataclass +class LifespanDependant: + caller: Callable[..., Any] + dependencies: List["LifespanDependant"] = field(default_factory=list) + name: Optional[str] = None + call: Optional[Callable[..., Any]] = None + use_cache: bool = True + cache_key: LifespanDependantCacheKey = field(init=False) + + def __post_init__(self) -> None: + if self.use_cache: + self.cache_key = self.call + else: + self.cache_key = (self.caller, self.name) + + +EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] + @dataclass -class Dependant: +class EndpointDependant: + endpoint_dependencies: List["EndpointDependant"] = field(default_factory=list) + lifespan_dependencies: List[LifespanDependant] = field(default_factory=list) + name: Optional[str] = None + call: Optional[Callable[..., Any]] = None + use_cache: bool = True + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( + init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) cookie_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list) - dependencies: List["Dependant"] = field(default_factory=list) security_requirements: List[SecurityRequirement] = field(default_factory=list) - name: Optional[str] = None - call: Optional[Callable[..., Any]] = None request_param_name: Optional[str] = None websocket_param_name: Optional[str] = None http_connection_param_name: Optional[str] = None @@ -29,9 +54,16 @@ class Dependant: background_tasks_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None security_scopes: Optional[List[str]] = None - use_cache: bool = True path: Optional[str] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) + + # Kept for backwards compatibility + @property + def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: + return tuple(self.endpoint_dependencies + self.lifespan_dependencies) + +# Kept for backwards compatibility +Dependant = EndpointDependant +CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 5cebbf00f..c58232b6d 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -51,7 +51,14 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) -from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.models import ( + CacheKey, + EndpointDependant, + LifespanDependant, + LifespanDependantCacheKey, + SecurityRequirement, +) +from fastapi.exceptions import FastAPIError from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -112,8 +119,9 @@ def get_param_sub_dependant( param_name: str, depends: params.Depends, path: str, + caller: Callable[..., Any], security_scopes: Optional[List[str]] = None, -) -> Dependant: +) -> Union[EndpointDependant, LifespanDependant]: assert depends.dependency return get_sub_dependant( depends=depends, @@ -121,14 +129,25 @@ def get_param_sub_dependant( path=path, name=param_name, security_scopes=security_scopes, + caller=caller, ) -def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: +def get_parameterless_sub_dependant( + *, + depends: params.Depends, + path: str, + caller: Callable[..., Any] +) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" - return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) + return get_sub_dependant( + depends=depends, + dependency=depends.dependency, + path=path, + caller=caller + ) def get_sub_dependant( @@ -136,57 +155,69 @@ def get_sub_dependant( depends: params.Depends, dependency: Callable[..., Any], path: str, + caller: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, -) -> Dependant: - security_requirement = None - security_scopes = security_scopes or [] - if isinstance(depends, params.Security): - dependency_scopes = depends.scopes - security_scopes.extend(dependency_scopes) - if isinstance(dependency, SecurityBase): - use_scopes: List[str] = [] - if isinstance(dependency, (OAuth2, OpenIdConnect)): - use_scopes = security_scopes - security_requirement = SecurityRequirement( - security_scheme=dependency, scopes=use_scopes +) -> Union[EndpointDependant, LifespanDependant]: + if depends.dependency_scope == "lifespan": + return get_lifespan_dependant( + caller=caller, + call=depends.dependency, + name=name, + use_cache=depends.use_cache + ) + elif depends.dependency_scope == "endpoint": + security_requirement = None + security_scopes = security_scopes or [] + if isinstance(depends, params.Security): + dependency_scopes = depends.scopes + security_scopes.extend(dependency_scopes) + if isinstance(dependency, SecurityBase): + use_scopes: List[str] = [] + if isinstance(dependency, (OAuth2, OpenIdConnect)): + use_scopes = security_scopes + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes + ) + sub_dependant = get_endpoint_dependant( + path=path, + call=dependency, + name=name, + security_scopes=security_scopes, + use_cache=depends.use_cache, + ) + if security_requirement: + sub_dependant.security_requirements.append(security_requirement) + return sub_dependant + else: + raise ValueError( + f"Dependency {name} of {caller} has an invalid " + f"sub-dependency scope: {depends.dependency_scope}" ) - sub_dependant = get_dependant( - path=path, - call=dependency, - name=name, - security_scopes=security_scopes, - use_cache=depends.use_cache, - ) - if security_requirement: - sub_dependant.security_requirements.append(security_requirement) - return sub_dependant - - -CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] def get_flat_dependant( - dependant: Dependant, + dependant: EndpointDependant, *, skip_repeats: bool = False, visited: Optional[List[CacheKey]] = None, -) -> Dependant: +) -> EndpointDependant: if visited is None: visited = [] visited.append(dependant.cache_key) - flat_dependant = Dependant( + flat_dependant = EndpointDependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), + lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, - path=dependant.path, + path=dependant.path ) - for sub_dependant in dependant.dependencies: + for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( @@ -198,6 +229,7 @@ def get_flat_dependant( flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) flat_dependant.security_requirements.extend(flat_sub.security_requirements) + flat_dependant.lifespan_dependencies.extend(flat_sub.lifespan_dependencies) return flat_dependant @@ -211,7 +243,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: return fields -def get_flat_params(dependant: Dependant) -> List[ModelField]: +def get_flat_params(dependant: EndpointDependant) -> List[ModelField]: flat_dependant = get_flat_dependant(dependant, skip_repeats=True) path_params = _get_flat_fields_from_params(flat_dependant.path_params) query_params = _get_flat_fields_from_params(flat_dependant.query_params) @@ -254,18 +286,64 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) -def get_dependant( +def get_lifespan_dependant( + *, + caller: Callable[..., Any], + call: Callable[..., Any], + name: Optional[str] = None, + use_cache: bool = True, +) -> LifespanDependant: + dependency_signature = get_typed_signature(call) + signature_params = dependency_signature.parameters + dependant = LifespanDependant( + call=call, + name=name, + use_cache=use_cache, + caller=caller + ) + for param_name, param in signature_params.items(): + param_details = analyze_param( + param_name=param_name, + annotation=param.annotation, + value=param.default, + is_path_param=False, + ) + if param_details.depends is None: + raise FastAPIError( + f"Lifespan dependency {dependant.name} was defined with an " + f"invalid argument {param_name}. Lifespan dependencies may " + f"only use other lifespan dependencies as arguments.") + + if param_details.depends.dependency_scope != "lifespan": + raise FastAPIError( + "Lifespan dependency may not use " + "sub-dependencies of other scopes." + ) + + sub_dependant = get_lifespan_dependant( + name=param_name, + call=param_details.depends.dependency, + use_cache=param_details.depends.use_cache, + caller=call + ) + dependant.dependencies.append(sub_dependant) + + return dependant + + + +def get_endpoint_dependant( *, path: str, call: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, -) -> Dependant: +) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - dependant = Dependant( + dependant = EndpointDependant( call=call, name=name, path=path, @@ -281,13 +359,28 @@ def get_dependant( is_path_param=is_path_param, ) if param_details.depends is not None: - sub_dependant = get_param_sub_dependant( - param_name=param_name, - depends=param_details.depends, - path=path, - security_scopes=security_scopes, - ) - dependant.dependencies.append(sub_dependant) + if param_details.depends.dependency_scope == "endpoint": + sub_dependant = get_param_sub_dependant( + param_name=param_name, + depends=param_details.depends, + path=path, + security_scopes=security_scopes, + caller=call, + ) + dependant.endpoint_dependencies.append(sub_dependant) + elif param_details.depends.dependency_scope == "lifespan": + sub_dependant = get_lifespan_dependant( + caller=call, + call=param_details.depends.dependency, + name=param_name, + use_cache=param_details.depends.use_cache, + ) + dependant.lifespan_dependencies.append(sub_dependant) + else: + raise FastAPIError( + f"Dependency \"{param_name}\" of `{call}` has an invalid " + f"sub-dependency scope: \"{param_details.depends.dependency_scope}\"" + ) continue if add_non_field_param_to_dependency( param_name=param_name, @@ -306,8 +399,12 @@ def get_dependant( return dependant +# Kept for backwards compatibility +get_dependant = get_endpoint_dependant + + def add_non_field_param_to_dependency( - *, param_name: str, type_annotation: Any, dependant: Dependant + *, param_name: str, type_annotation: Any, dependant: EndpointDependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name @@ -501,7 +598,7 @@ def analyze_param( return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) -def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: +def add_param_to_fields(*, field: ModelField, dependant: EndpointDependant) -> None: field_info = field.field_info field_info_in = getattr(field_info, "in_", None) if field_info_in == params.ParamTypes.path: @@ -550,6 +647,82 @@ async def solve_generator( return await stack.enter_async_context(cm) +@dataclass +class SolvedLifespanDependant: + value: Any + dependency_cache: Dict[Callable[..., Any], Any] + + +async def solve_lifespan_dependant( + *, + dependant: LifespanDependant, + dependency_overrides_provider: Optional[Any] = None, + dependency_cache: Optional[Dict[LifespanDependantCacheKey, Callable[..., Any]]] = None, + async_exit_stack: AsyncExitStack, +) -> SolvedLifespanDependant: + dependency_cache = dependency_cache or {} + if dependant.use_cache and dependant.cache_key in dependency_cache: + return SolvedLifespanDependant( + value=dependency_cache[dependant.cache_key], + dependency_cache=dependency_cache, + ) + + dependency_arguments: Dict[str, Any] = {} + sub_dependant: LifespanDependant + for sub_dependant in dependant.dependencies: + sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) + sub_dependant.cache_key = cast( + Callable[..., Any], sub_dependant.cache_key + ) + assert sub_dependant.name, ( + "Lifespan scoped dependencies should not be able to have " + "subdependencies with no name" + ) + + sub_dependant_to_solve = sub_dependant + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): + original_call = sub_dependant.call + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) + sub_dependant_to_solve = get_lifespan_dependant( + call=call, + name=sub_dependant.name, + caller=dependant.call + ) + + solved_sub_dependant = await solve_lifespan_dependant( + dependant=sub_dependant_to_solve, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack, + ) + dependency_cache.update(solved_sub_dependant.dependency_cache) + dependency_arguments[sub_dependant.name] = solved_sub_dependant.value + + if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call): + value = await solve_generator( + call=dependant.call, + stack=async_exit_stack, + sub_values=dependency_arguments + ) + elif is_coroutine_callable(dependant.call): + value = await dependant.call(**dependency_arguments) + else: + value = await run_in_threadpool(dependant.call, **dependency_arguments) + + if dependant.cache_key not in dependency_cache: + dependency_cache[dependant.cache_key] = value + + return SolvedLifespanDependant( + value=value, + dependency_cache=dependency_cache, + ) + + @dataclass class SolvedDependency: values: Dict[str, Any] @@ -562,7 +735,7 @@ class SolvedDependency: async def solve_dependencies( *, request: Union[Request, WebSocket], - dependant: Dependant, + dependant: EndpointDependant, body: Optional[Union[Dict[str, Any], FormData]] = None, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, @@ -573,13 +746,35 @@ async def solve_dependencies( ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] + + for sub_dependant in dependant.lifespan_dependencies: + if sub_dependant.name is None: + continue + try: + lifespan_scoped_dependencies = request.state.__fastapi__[ + "lifespan_scoped_dependencies"] + except AttributeError as e: + raise FastAPIError( + "FastAPI's internal lifespan was not initialized" + ) from e + + try: + value = lifespan_scoped_dependencies[sub_dependant.cache_key] + except KeyError as e: + raise FastAPIError( + f"Dependency {sub_dependant.name} of {dependant.call} " + f"was not initialized." + ) from e + + values[sub_dependant.name] = value + if response is None: response = Response() del response.headers["content-length"] response.status_code = None # type: ignore + dependency_cache = dependency_cache or {} - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: + for sub_dependant in dependant.endpoint_dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.cache_key = cast( Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key @@ -595,7 +790,7 @@ async def solve_dependencies( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore - use_sub_dependant = get_dependant( + use_sub_dependant = get_endpoint_dependant( path=use_path, call=call, name=sub_dependant.name, @@ -910,7 +1105,7 @@ async def request_body_to_args( def get_body_field( - *, flat_dependant: Dependant, name: str, embed_body_fields: bool + *, flat_dependant: EndpointDependant, name: str, embed_body_fields: bool ) -> Optional[ModelField]: """ Get a ModelField representing the request body for a path operation, combining diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py new file mode 100644 index 000000000..184c943d7 --- /dev/null +++ b/fastapi/lifespan.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Callable, Dict, List + +from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey +from fastapi.dependencies.utils import solve_lifespan_dependant +from fastapi.routing import APIRoute + +if TYPE_CHECKING: + from fastapi import FastAPI + + +def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: + lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} + for route in app.router.routes: + if not isinstance(route, APIRoute): + continue + + for sub_dependant in route.lifespan_dependencies: + if sub_dependant.cache_key in lifespan_dependants_cache: + continue + + lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant + + return list(lifespan_dependants_cache.values()) + + +async def resolve_lifespan_dependants( + *, + app: FastAPI, + async_exit_stack: AsyncExitStack +) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: + lifespan_dependants = _get_lifespan_dependants(app) + dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} + for lifespan_dependant in lifespan_dependants: + solved_dependency = await solve_lifespan_dependant( + dependant=lifespan_dependant, + dependency_overrides_provider=app, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack + ) + + dependency_cache.update(solved_dependency.dependency_cache) + + return dependency_cache diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 947eca948..425e919cf 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -15,7 +15,7 @@ from fastapi._compat import ( lenient_issubclass, ) from fastapi.datastructures import DefaultPlaceholder -from fastapi.dependencies.models import Dependant +from fastapi.dependencies.models import EndpointDependant from fastapi.dependencies.utils import ( _get_flat_fields_from_params, get_flat_dependant, @@ -75,7 +75,7 @@ status_code_ranges: Dict[str, str] = { def get_openapi_security_definitions( - flat_dependant: Dependant, + flat_dependant: EndpointDependant, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: security_definitions = {} operation_security = [] @@ -93,7 +93,7 @@ def get_openapi_security_definitions( def _get_openapi_operation_parameters( *, - dependant: Dependant, + dependant: EndpointDependant, schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, field_mapping: Dict[ diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 7ddaace25..9ff744414 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example +from fastapi.params import DependencyScope from typing_extensions import Annotated, Doc, deprecated _Unset: Any = Undefined @@ -2244,6 +2247,33 @@ def Depends( # noqa: N802 """ ), ] = True, + dependency_scope: Annotated[ + DependencyScope, + Doc( + """ + The scope in which the dependency value should be evaluated. Can be + either `"endpoint"` or `"lifespan"`. + + If `dependency_scope` is set to "endpoint" (the default), the + dependency will be setup and teardown for every request. + + If `dependency_scope` is set to `"lifespan"` the dependency would + be setup at the start of the entire application's lifespan. The + evaluated dependency would be then reused across all endpoints. + The dependency would be teared down as a part of the application's + shutdown process. + + Note that dependencies defined with the `"endpoint"` scope may use + sub-dependencies defined with the `"lifespan"` scope, but not the + other way around; + Dependencies defined with the `"lifespan"` scope may not use + sub-dependencies with `"endpoint"` scope, nor can they use + other "endpoint scoped" arguments such as "Path", "Body", "Query", + or any other annotation which does not make sense in a scope of an + application's entire lifespan. + """ + ) + ] = "endpoint" ) -> Any: """ Declare a FastAPI dependency. @@ -2274,7 +2304,7 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends(dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope) def Security( # noqa: N802 diff --git a/fastapi/params.py b/fastapi/params.py index 90ca7cb01..f655acba8 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -4,11 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi.openapi.models import Example from pydantic.fields import FieldInfo -from typing_extensions import Annotated, deprecated +from typing_extensions import Annotated, Literal, TypeAlias, deprecated from ._compat import PYDANTIC_V2, PYDANTIC_VERSION, Undefined _Unset: Any = Undefined +DependencyScope: TypeAlias = Literal["endpoint", "lifespan"] class ParamTypes(Enum): @@ -759,15 +760,25 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + dependency_scope: DependencyScope = "endpoint" ): self.dependency = dependency self.use_cache = use_cache + self.dependency_scope = dependency_scope def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" + if self.dependency_scope == "endpoint": + dependency_scope = "" + else: + dependency_scope = f", dependency_scope=\"{self.dependency_scope}\"" + + return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})" class Security(Depends): @@ -778,5 +789,9 @@ class Security(Depends): scopes: Optional[Sequence[str]] = None, use_cache: bool = True, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__( + dependency=dependency, + use_cache=use_cache, + dependency_scope="endpoint" + ) self.scopes = scopes or [] diff --git a/fastapi/routing.py b/fastapi/routing.py index 86e303602..c3f446037 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -31,11 +31,11 @@ from fastapi._compat import ( lenient_issubclass, ) from fastapi.datastructures import Default, DefaultPlaceholder -from fastapi.dependencies.models import Dependant +from fastapi.dependencies.models import EndpointDependant, LifespanDependant from fastapi.dependencies.utils import ( _should_embed_body_fields, get_body_field, - get_dependant, + get_endpoint_dependant, get_flat_dependant, get_parameterless_sub_dependant, get_typed_return_annotation, @@ -73,7 +73,7 @@ from starlette.routing import ( from starlette.routing import Mount as Mount # noqa from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket -from typing_extensions import Annotated, Doc, deprecated +from typing_extensions import Annotated, Doc, assert_never, deprecated def _prepare_response_content( @@ -123,7 +123,7 @@ def _prepare_response_content( return res -def _merge_lifespan_context( +def merge_lifespan_context( original_context: Lifespan[Any], nested_context: Lifespan[Any] ) -> Lifespan[Any]: @asynccontextmanager @@ -202,7 +202,7 @@ async def serialize_response( async def run_endpoint_function( - *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool + *, dependant: EndpointDependant, values: Dict[str, Any], is_coroutine: bool ) -> Any: # Only called by get_request_handler. Has been split into its own function to # facilitate profiling endpoints, since inner functions are harder to profile. @@ -215,7 +215,7 @@ async def run_endpoint_function( def get_request_handler( - dependant: Dependant, + dependant: EndpointDependant, body_field: Optional[ModelField] = None, status_code: Optional[int] = None, response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse), @@ -358,7 +358,7 @@ def get_request_handler( def get_websocket_app( - dependant: Dependant, + dependant: EndpointDependant, dependency_overrides_provider: Optional[Any] = None, embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: @@ -400,12 +400,20 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + sub_dependant = get_parameterless_sub_dependant( + depends=depends, + path=self.path_format, + caller=self ) + if depends.dependency_scope == "endpoint": + self.dependant.endpoint_dependencies.insert(0, sub_dependant) + elif depends.dependency_scope == "lifespan": + self.dependant.lifespan_dependencies.insert(0, sub_dependant) + else: + assert_never(depends.dependency_scope) + self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params @@ -424,6 +432,10 @@ class APIWebSocketRoute(routing.WebSocketRoute): child_scope["route"] = self return match, child_scope + @property + def lifespan_dependencies(self) -> List[LifespanDependant]: + return self._flat_dependant.lifespan_dependencies + class APIRoute(routing.Route): def __init__( @@ -549,12 +561,19 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + sub_dependant = get_parameterless_sub_dependant( + depends=depends, + path=self.path_format, + caller=self.__call__ ) + if depends.dependency_scope == "endpoint": + self.dependant.endpoint_dependencies.insert(0, sub_dependant) + elif depends.dependency_scope == "lifespan": + self.dependant.lifespan_dependencies.insert(0, sub_dependant) + else: + assert_never(depends.dependency_scope) self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params @@ -589,6 +608,10 @@ class APIRoute(routing.Route): child_scope["route"] = self return match, child_scope + @property + def lifespan_dependencies(self) -> List[LifespanDependant]: + return self._flat_dependant.lifespan_dependencies + class APIRouter(routing.Router): """ @@ -1356,7 +1379,7 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) - self.lifespan_context = _merge_lifespan_context( + self.lifespan_context = merge_lifespan_context( self.lifespan_context, router.lifespan_context, ) diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py new file mode 100644 index 000000000..40d82f0e0 --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies.py @@ -0,0 +1,703 @@ +from enum import StrEnum, auto +from typing import Any, AsyncGenerator, List, Tuple, TypeVar + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, +) +from fastapi.exceptions import FastAPIError +from fastapi.params import Security +from fastapi.security import SecurityScopes +from starlette.testclient import TestClient +from typing_extensions import Annotated, Generator, Literal, assert_never + +T = TypeVar('T') + + +class DependencyStyle(StrEnum): + SYNC_FUNCTION = auto() + ASYNC_FUNCTION = auto() + SYNC_GENERATOR = auto() + ASYNC_GENERATOR = auto() + + +class DependencyFactory: + def __init__( + self, + dependency_style: DependencyStyle, *, + should_error: bool = False + ): + self.activation_times = 0 + self.deactivation_times = 0 + self.dependency_style = dependency_style + self._should_error = should_error + + def get_dependency(self): + if self.dependency_style == DependencyStyle.SYNC_FUNCTION: + return self._synchronous_function_dependency + + if self.dependency_style == DependencyStyle.SYNC_GENERATOR: + return self._synchronous_generator_dependency + + if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: + return self._asynchronous_function_dependency + + if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: + return self._asynchronous_generator_dependency + + assert_never(self.dependency_style) + + async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + yield self.activation_times + self.deactivation_times += 1 + + def _synchronous_generator_dependency(self) -> Generator[T, None, None]: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + yield self.activation_times + self.deactivation_times += 1 + + async def _asynchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + return self.activation_times + + def _synchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + return self.activation_times + + +def _expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + response = client.post(url) + response.raise_for_status() + assert response.json() == expected_response + + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION + ): + assert dependency_factory.deactivation_times == expected_activation_times + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + @router.post("/test") + async def endpoint( + dependency: Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + )] + ) -> None: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1 + ) + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends]) + + @app.post("/test") + async def endpoint() -> None: + return None + else: + app = FastAPI() + router = APIRouter(dependencies=[depends]) + + @router.post("/test") + async def endpoint() -> None: + return None + + app.include_router(router) + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 + ) + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"] +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + @router.post("/test") + async def endpoint( + dependency: Annotated[List[int], Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + )] + ) -> List[int]: + return dependency + + if routing_style == "router": + app.include_router(router) + + if use_cache: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1]), + ("/test", [1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1 + ) + else: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2]), + ("/test", [1, 2]), + ], + dependency_factory=dependency_factory, + expected_activation_times=2 + ) + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + @router.post("/test1") + async def endpoint( + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)] + ) -> List[int]: + return [dependency1, dependency2, dependency3] + + if routing_style == "router": + app.include_router(router) + + if use_cache: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1 + ) + else: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test1", [1, 2, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=3 + ) + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + @router.post("/test1") + async def endpoint( + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)] + ) -> List[int]: + return [dependency1, dependency2, dependency3] + + @router.post("/test2") + async def endpoint2( + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)] + ) -> List[int]: + return [dependency1, dependency2, dependency3] + + if routing_style == "router": + app.include_router(router) + + if use_cache: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1 + ) + else: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=5 + ) + +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends], + ) -> int: + return dependency + + if routing_style == "router": + app.include_router(router) + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1 + ) + + +@pytest.mark.parametrize("annotation", [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, +]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation +): + async def dependency_func(param: annotation) -> None: + yield + + app = FastAPI() + + with pytest.raises(FastAPIError): + @app.post("/test") + async def endpoint( + dependency: Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan")] + ) -> None: + return + + +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( + dependency_style: DependencyStyle +): + dependency_factory = DependencyFactory(dependency_style) + + async def lifespan_scoped_dependency( + param: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )] + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + @app.post("/test") + async def endpoint( + dependency: Annotated[int, Depends( + lifespan_scoped_dependency, + dependency_scope="lifespan" + )] + ) -> int: + return dependency + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2 + ) + + +@pytest.mark.parametrize("depends_class", [Depends, Security]) +@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ + "websocket", "endpoint" +]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, + route_type +): + async def sub_dependency() -> None: + pass + + async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + yield + + app = FastAPI() + route_decorator = route_type(app, "/test") + + with pytest.raises(FastAPIError): + @route_decorator + async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + ) -> None: + return + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_dependencies_must_provide_correct_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + with pytest.raises(FastAPIError): + @router.post("/test") + async def endpoint( + dependency: Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + )] + ) -> None: + assert dependency == 1 + return dependency + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_incorrect_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + with pytest.raises(FastAPIError): + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} + + try: + with pytest.raises(FastAPIError): + client.post("/test") + finally: + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_internal_lifespan( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + internal_state = client.app_state["__fastapi__"] + del client.app_state["__fastapi__"] + + try: + with pytest.raises(FastAPIError): + client.post("/test") + finally: + client.app_state["__fastapi__"] = internal_state + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): + dependency_factory= DependencyFactory(dependency_style, should_error=True) + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + with pytest.raises(ValueError) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) + + +# TODO: Add tests for dependency_overrides +# TODO: Add a websocket equivalent to all tests diff --git a/tests/test_params_repr.py b/tests/test_params_repr.py index bfc7bed09..10f044888 100644 --- a/tests/test_params_repr.py +++ b/tests/test_params_repr.py @@ -1,5 +1,6 @@ from typing import Any, List +import pytest from dirty_equals import IsOneOf from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query @@ -143,10 +144,16 @@ def test_body_repr_list(): assert repr(Body([])) == "Body([])" -def test_depends_repr(): - assert repr(Depends()) == "Depends(NoneType)" - assert repr(Depends(get_user)) == "Depends(get_user)" - assert repr(Depends(use_cache=False)) == "Depends(NoneType, use_cache=False)" - assert ( - repr(Depends(get_user, use_cache=False)) == "Depends(get_user, use_cache=False)" - ) +@pytest.mark.parametrize(["depends", "expected_repr"], [ + [Depends(), "Depends(NoneType)"], + [Depends(get_user), "Depends(get_user)"], + [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], + [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], + + [Depends(dependency_scope="lifespan"), "Depends(NoneType, dependency_scope=\"lifespan\")"], + [Depends(get_user, dependency_scope="lifespan"), "Depends(get_user, dependency_scope=\"lifespan\")"], + [Depends(use_cache=False, dependency_scope="lifespan"), "Depends(NoneType, use_cache=False, dependency_scope=\"lifespan\")"], + [Depends(get_user, use_cache=False, dependency_scope="lifespan"), "Depends(get_user, use_cache=False, dependency_scope=\"lifespan\")"], +]) +def test_depends_repr(depends, expected_repr): + assert repr(depends) == expected_repr diff --git a/tests/test_router_events.py b/tests/test_router_events.py index dd7ff3314..8289a7301 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -199,6 +199,9 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None: "app_specific": True, "router_specific": True, "overridden": "app", + "__fastapi__": { + "lifespan_scoped_dependencies": {} + }, } @@ -216,7 +219,11 @@ def test_merged_no_return_lifespans_return_none() -> None: app.include_router(router) with TestClient(app) as client: - assert not client.app_state + assert client.app_state == { + "__fastapi__": { + "lifespan_scoped_dependencies": {} + } + } def test_merged_mixed_state_lifespans() -> None: @@ -239,4 +246,9 @@ def test_merged_mixed_state_lifespans() -> None: app.include_router(router) with TestClient(app) as client: - assert client.app_state == {"router": True} + assert client.app_state == { + "router": True, + "__fastapi__": { + "lifespan_scoped_dependencies": {} + } + } From 54ecfb87d87cdc7c087c1381f2f543be7e6ed039 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:12:24 +0000 Subject: [PATCH 02/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/applications.py | 11 +- fastapi/dependencies/models.py | 14 +- fastapi/dependencies/utils.py | 59 ++--- fastapi/lifespan.py | 6 +- fastapi/param_functions.py | 8 +- fastapi/params.py | 8 +- fastapi/routing.py | 16 +- tests/test_lifespan_scoped_dependencies.py | 261 +++++++++++---------- tests/test_params_repr.py | 36 ++- tests/test_router_events.py | 14 +- 10 files changed, 214 insertions(+), 219 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 625690a74..6b2b90336 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -938,10 +938,7 @@ class FastAPI(Starlette): if lifespan is None: lifespan = FastAPI._internal_lifespan else: - lifespan = merge_lifespan_context( - FastAPI._internal_lifespan, - lifespan - ) + lifespan = merge_lifespan_context(FastAPI._internal_lifespan, lifespan) # Since we always use a lifespan, starlette will no longer run event # handlers which are defined in the scope of the application. @@ -985,8 +982,7 @@ class FastAPI(Starlette): async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]: async with AsyncExitStack() as exit_stack: lifespan_scoped_dependencies = await resolve_lifespan_dependants( - app=self, - async_exit_stack=exit_stack + app=self, async_exit_stack=exit_stack ) try: for handler in self._on_startup: @@ -1006,7 +1002,6 @@ class FastAPI(Starlette): else: await run_in_threadpool(handler) - def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. This is called by FastAPI @@ -4536,12 +4531,14 @@ class FastAPI(Starlette): Read more about it in the [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated). """ + def decorator(func: DecoratedCallable) -> DecoratedCallable: if event_type == "startup": self._on_startup.append(func) else: self._on_shutdown.append(func) return func + return decorator def middleware( diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 471f9c402..df72e7f5f 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -12,7 +12,10 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None -LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]] +LifespanDependantCacheKey: TypeAlias = Union[ + Tuple[Callable[..., Any], str], Callable[..., Any] +] + @dataclass class LifespanDependant: @@ -30,7 +33,10 @@ class LifespanDependant: self.cache_key = (self.caller, self.name) -EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] +EndpointDependantCacheKey: TypeAlias = Tuple[ + Optional[Callable[..., Any]], Tuple[str, ...] +] + @dataclass class EndpointDependant: @@ -39,8 +45,7 @@ class EndpointDependant: name: Optional[str] = None call: Optional[Callable[..., Any]] = None use_cache: bool = True - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( - init=False) + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) @@ -64,6 +69,7 @@ class EndpointDependant: def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: return tuple(self.endpoint_dependencies + self.lifespan_dependencies) + # Kept for backwards compatibility Dependant = EndpointDependant CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0e9cfe244..527224584 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -134,19 +134,13 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant( - *, - depends: params.Depends, - path: str, - caller: Callable[..., Any] + *, depends: params.Depends, path: str, caller: Callable[..., Any] ) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" return get_sub_dependant( - depends=depends, - dependency=depends.dependency, - path=path, - caller=caller + depends=depends, dependency=depends.dependency, path=path, caller=caller ) @@ -164,7 +158,7 @@ def get_sub_dependant( caller=caller, call=depends.dependency, name=name, - use_cache=depends.use_cache + use_cache=depends.use_cache, ) elif depends.dependency_scope == "endpoint": security_requirement = None @@ -215,7 +209,7 @@ def get_flat_dependant( security_requirements=dependant.security_requirements.copy(), lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, - path=dependant.path + path=dependant.path, ) for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: @@ -296,10 +290,7 @@ def get_lifespan_dependant( dependency_signature = get_typed_signature(call) signature_params = dependency_signature.parameters dependant = LifespanDependant( - call=call, - name=name, - use_cache=use_cache, - caller=caller + call=call, name=name, use_cache=use_cache, caller=caller ) for param_name, param in signature_params.items(): param_details = analyze_param( @@ -312,26 +303,25 @@ def get_lifespan_dependant( raise FastAPIError( f"Lifespan dependency {dependant.name} was defined with an " f"invalid argument {param_name}. Lifespan dependencies may " - f"only use other lifespan dependencies as arguments.") + f"only use other lifespan dependencies as arguments." + ) if param_details.depends.dependency_scope != "lifespan": raise FastAPIError( - "Lifespan dependency may not use " - "sub-dependencies of other scopes." + "Lifespan dependency may not use " "sub-dependencies of other scopes." ) sub_dependant = get_lifespan_dependant( name=param_name, call=param_details.depends.dependency, use_cache=param_details.depends.use_cache, - caller=call + caller=call, ) dependant.dependencies.append(sub_dependant) return dependant - def get_endpoint_dependant( *, path: str, @@ -378,8 +368,8 @@ def get_endpoint_dependant( dependant.lifespan_dependencies.append(sub_dependant) else: raise FastAPIError( - f"Dependency \"{param_name}\" of `{call}` has an invalid " - f"sub-dependency scope: \"{param_details.depends.dependency_scope}\"" + f'Dependency "{param_name}" of `{call}` has an invalid ' + f'sub-dependency scope: "{param_details.depends.dependency_scope}"' ) continue if add_non_field_param_to_dependency( @@ -659,7 +649,9 @@ async def solve_lifespan_dependant( *, dependant: LifespanDependant, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[LifespanDependantCacheKey, Callable[..., Any]]] = None, + dependency_cache: Optional[ + Dict[LifespanDependantCacheKey, Callable[..., Any]] + ] = None, async_exit_stack: AsyncExitStack, ) -> SolvedLifespanDependant: dependency_cache = dependency_cache or {} @@ -673,9 +665,7 @@ async def solve_lifespan_dependant( sub_dependant: LifespanDependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Callable[..., Any], sub_dependant.cache_key - ) + sub_dependant.cache_key = cast(Callable[..., Any], sub_dependant.cache_key) assert sub_dependant.name, ( "Lifespan scoped dependencies should not be able to have " "subdependencies with no name" @@ -691,9 +681,7 @@ async def solve_lifespan_dependant( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) sub_dependant_to_solve = get_lifespan_dependant( - call=call, - name=sub_dependant.name, - caller=dependant.call + call=call, name=sub_dependant.name, caller=dependant.call ) solved_sub_dependant = await solve_lifespan_dependant( @@ -707,9 +695,7 @@ async def solve_lifespan_dependant( if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call): value = await solve_generator( - call=dependant.call, - stack=async_exit_stack, - sub_values=dependency_arguments + call=dependant.call, stack=async_exit_stack, sub_values=dependency_arguments ) elif is_coroutine_callable(dependant.call): value = await dependant.call(**dependency_arguments) @@ -754,18 +740,17 @@ async def solve_dependencies( continue try: lifespan_scoped_dependencies = request.state.__fastapi__[ - "lifespan_scoped_dependencies"] + "lifespan_scoped_dependencies" + ] except AttributeError as e: - raise FastAPIError( - "FastAPI's internal lifespan was not initialized" - ) from e + raise FastAPIError("FastAPI's internal lifespan was not initialized") from e try: value = lifespan_scoped_dependencies[sub_dependant.cache_key] except KeyError as e: raise FastAPIError( - f"Dependency {sub_dependant.name} of {dependant.call} " - f"was not initialized." + f"Dependency {sub_dependant.name} of {dependant.call} " + f"was not initialized." ) from e values[sub_dependant.name] = value diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py index 184c943d7..bff285267 100644 --- a/fastapi/lifespan.py +++ b/fastapi/lifespan.py @@ -27,9 +27,7 @@ def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: async def resolve_lifespan_dependants( - *, - app: FastAPI, - async_exit_stack: AsyncExitStack + *, app: FastAPI, async_exit_stack: AsyncExitStack ) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: lifespan_dependants = _get_lifespan_dependants(app) dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} @@ -38,7 +36,7 @@ async def resolve_lifespan_dependants( dependant=lifespan_dependant, dependency_overrides_provider=app, dependency_cache=dependency_cache, - async_exit_stack=async_exit_stack + async_exit_stack=async_exit_stack, ) dependency_cache.update(solved_dependency.dependency_cache) diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index e9a15c166..4112e90e2 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -2272,8 +2272,8 @@ def Depends( # noqa: N802 or any other annotation which does not make sense in a scope of an application's entire lifespan. """ - ) - ] = "endpoint" + ), + ] = "endpoint", ) -> Any: """ Declare a FastAPI dependency. @@ -2304,7 +2304,9 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope) + return params.Depends( + dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope + ) def Security( # noqa: N802 diff --git a/fastapi/params.py b/fastapi/params.py index f655acba8..954d4b27d 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -764,7 +764,7 @@ class Depends: dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True, - dependency_scope: DependencyScope = "endpoint" + dependency_scope: DependencyScope = "endpoint", ): self.dependency = dependency self.use_cache = use_cache @@ -776,7 +776,7 @@ class Depends: if self.dependency_scope == "endpoint": dependency_scope = "" else: - dependency_scope = f", dependency_scope=\"{self.dependency_scope}\"" + dependency_scope = f', dependency_scope="{self.dependency_scope}"' return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})" @@ -790,8 +790,6 @@ class Security(Depends): use_cache: bool = True, ): super().__init__( - dependency=dependency, - use_cache=use_cache, - dependency_scope="endpoint" + dependency=dependency, use_cache=use_cache, dependency_scope="endpoint" ) self.scopes = scopes or [] diff --git a/fastapi/routing.py b/fastapi/routing.py index 50b51a352..24ea39268 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -400,12 +400,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for depends in self.dependencies[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self + depends=depends, path=self.path_format, caller=self ) if depends.dependency_scope == "endpoint": self.dependant.endpoint_dependencies.insert(0, sub_dependant) @@ -563,12 +563,12 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for depends in self.dependencies[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self.__call__ + depends=depends, path=self.path_format, caller=self.__call__ ) if depends.dependency_scope == "endpoint": self.dependant.endpoint_dependencies.insert(0, sub_dependant) diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py index 40d82f0e0..70cbc5bb3 100644 --- a/tests/test_lifespan_scoped_dependencies.py +++ b/tests/test_lifespan_scoped_dependencies.py @@ -21,7 +21,7 @@ from fastapi.security import SecurityScopes from starlette.testclient import TestClient from typing_extensions import Annotated, Generator, Literal, assert_never -T = TypeVar('T') +T = TypeVar("T") class DependencyStyle(StrEnum): @@ -33,9 +33,7 @@ class DependencyStyle(StrEnum): class DependencyFactory: def __init__( - self, - dependency_style: DependencyStyle, *, - should_error: bool = False + self, dependency_style: DependencyStyle, *, should_error: bool = False ): self.activation_times = 0 self.deactivation_times = 0 @@ -89,11 +87,11 @@ class DependencyFactory: def _expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, ) -> None: assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 @@ -111,16 +109,19 @@ def _expect_correct_amount_of_dependency_activations( assert dependency_factory.activation_times == expected_activation_times if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, ): assert dependency_factory.deactivation_times == expected_activation_times + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): - dependency_factory= DependencyFactory(dependency_style) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, routing_style, use_cache +): + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -131,11 +132,14 @@ def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, @router.post("/test") async def endpoint( - dependency: Annotated[None, Depends( + dependency: Annotated[ + None, + Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", use_cache=use_cache, - )] + ), + ], ) -> None: assert dependency == 1 return dependency @@ -147,23 +151,22 @@ def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, app=app, dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 + expected_activation_times=1, ) + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) if routing_style == "app": @@ -186,7 +189,7 @@ def test_router_dependencies( app=app, dependency_factory=dependency_factory, urls_and_responses=[("/test", None)] * 2, - expected_activation_times=1 + expected_activation_times=1, ) @@ -195,17 +198,17 @@ def test_router_dependencies( @pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"] + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -217,18 +220,21 @@ def test_dependency_cache_in_same_dependency( router = APIRouter() async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], ) -> List[int]: return [sub_dependency1, sub_dependency2] @router.post("/test") async def endpoint( - dependency: Annotated[List[int], Depends( + dependency: Annotated[ + List[int], + Depends( dependency, use_cache=use_cache, dependency_scope=main_dependency_scope, - )] + ), + ], ) -> List[int]: return dependency @@ -243,7 +249,7 @@ def test_dependency_cache_in_same_dependency( ("/test", [1, 1]), ], dependency_factory=dependency_factory, - expected_activation_times=1 + expected_activation_times=1, ) else: _expect_correct_amount_of_dependency_activations( @@ -253,7 +259,7 @@ def test_dependency_cache_in_same_dependency( ("/test", [1, 2]), ], dependency_factory=dependency_factory, - expected_activation_times=2 + expected_activation_times=2, ) @@ -261,16 +267,14 @@ def test_dependency_cache_in_same_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -286,9 +290,9 @@ def test_dependency_cache_in_same_endpoint( @router.post("/test1") async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)], ) -> List[int]: return [dependency1, dependency2, dependency3] @@ -303,7 +307,7 @@ def test_dependency_cache_in_same_endpoint( ("/test1", [1, 1, 1]), ], dependency_factory=dependency_factory, - expected_activation_times=1 + expected_activation_times=1, ) else: _expect_correct_amount_of_dependency_activations( @@ -313,23 +317,22 @@ def test_dependency_cache_in_same_endpoint( ("/test1", [1, 2, 3]), ], dependency_factory=dependency_factory, - expected_activation_times=3 + expected_activation_times=3, ) + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -345,17 +348,17 @@ def test_dependency_cache_in_different_endpoints( @router.post("/test1") async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)], ) -> List[int]: return [dependency1, dependency2, dependency3] @router.post("/test2") async def endpoint2( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)], ) -> List[int]: return [dependency1, dependency2, dependency3] @@ -372,7 +375,7 @@ def test_dependency_cache_in_different_endpoints( ("/test2", [1, 1, 1]), ], dependency_factory=dependency_factory, - expected_activation_times=1 + expected_activation_times=1, ) else: _expect_correct_amount_of_dependency_activations( @@ -384,21 +387,22 @@ def test_dependency_cache_in_different_endpoints( ("/test2", [4, 5, 3]), ], dependency_factory=dependency_factory, - expected_activation_times=5 + expected_activation_times=5, ) + @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, + dependency_style: DependencyStyle, + routing_style, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=False + use_cache=False, ) app = FastAPI() @@ -411,7 +415,7 @@ def test_no_cached_dependency( @router.post("/test") async def endpoint( - dependency: Annotated[int, depends], + dependency: Annotated[int, depends], ) -> int: return dependency @@ -422,49 +426,52 @@ def test_no_cached_dependency( app=app, dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 + expected_activation_times=1, ) -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) -def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation -): +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + ], +) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(annotation): async def dependency_func(param: annotation) -> None: yield app = FastAPI() with pytest.raises(FastAPIError): + @app.post("/test") async def endpoint( - dependency: Annotated[ - None, Depends(dependency_func, dependency_scope="lifespan")] + dependency: Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) -> None: return @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( - dependency_style: DependencyStyle + dependency_style: DependencyStyle, ): dependency_factory = DependencyFactory(dependency_style) async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) -> AsyncGenerator[int, None]: yield param @@ -472,10 +479,9 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( @app.post("/test") async def endpoint( - dependency: Annotated[int, Depends( - lifespan_scoped_dependency, - dependency_scope="lifespan" - )] + dependency: Annotated[ + int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + ], ) -> int: return dependency @@ -483,42 +489,44 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( app=app, dependency_factory=dependency_factory, expected_activation_times=1, - urls_and_responses=[("/test", 1)] * 2 + urls_and_responses=[("/test", 1)] * 2, ) @pytest.mark.parametrize("depends_class", [Depends, Security]) -@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ - "websocket", "endpoint" -]) +@pytest.mark.parametrize( + "route_type", [FastAPI.post, FastAPI.websocket], ids=["websocket", "endpoint"] +) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - route_type + depends_class, route_type ): async def sub_dependency() -> None: pass - async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + async def dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: yield app = FastAPI() route_decorator = route_type(app, "/test") with pytest.raises(FastAPIError): + @route_decorator - async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + async def endpoint( + x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], ) -> None: return + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_dependencies_must_provide_correct_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -528,13 +536,17 @@ def test_dependencies_must_provide_correct_dependency_scope( router = APIRouter() with pytest.raises(FastAPIError): + @router.post("/test") async def endpoint( - dependency: Annotated[None, Depends( + dependency: Annotated[ + None, + Depends( dependency_factory.get_dependency(), dependency_scope="incorrect", use_cache=use_cache, - )] + ), + ], ) -> None: assert dependency == 1 return dependency @@ -544,11 +556,9 @@ def test_dependencies_must_provide_correct_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_incorrect_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -567,10 +577,9 @@ def test_endpoints_report_incorrect_dependency_scope( depends.dependency_scope = "asdad" with pytest.raises(FastAPIError): + @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency @@ -579,11 +588,9 @@ def test_endpoints_report_incorrect_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -599,9 +606,7 @@ def test_endpoints_report_uninitialized_dependency( ) @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency @@ -616,18 +621,18 @@ def test_endpoints_report_uninitialized_dependency( with pytest.raises(FastAPIError): client.post("/test") finally: - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( + dependencies + ) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_internal_lifespan( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -643,9 +648,7 @@ def test_endpoints_report_uninitialized_internal_lifespan( ) @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency @@ -666,8 +669,10 @@ def test_endpoints_report_uninitialized_internal_lifespan( @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): - dependency_factory= DependencyFactory(dependency_style, should_error=True) +def test_bad_lifespan_scoped_dependencies( + use_cache, dependency_style: DependencyStyle, routing_style +): + dependency_factory = DependencyFactory(dependency_style, should_error=True) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", @@ -683,9 +688,7 @@ def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: Dependenc router = APIRouter() @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency diff --git a/tests/test_params_repr.py b/tests/test_params_repr.py index 10f044888..8921026b2 100644 --- a/tests/test_params_repr.py +++ b/tests/test_params_repr.py @@ -144,16 +144,30 @@ def test_body_repr_list(): assert repr(Body([])) == "Body([])" -@pytest.mark.parametrize(["depends", "expected_repr"], [ - [Depends(), "Depends(NoneType)"], - [Depends(get_user), "Depends(get_user)"], - [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], - [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], - - [Depends(dependency_scope="lifespan"), "Depends(NoneType, dependency_scope=\"lifespan\")"], - [Depends(get_user, dependency_scope="lifespan"), "Depends(get_user, dependency_scope=\"lifespan\")"], - [Depends(use_cache=False, dependency_scope="lifespan"), "Depends(NoneType, use_cache=False, dependency_scope=\"lifespan\")"], - [Depends(get_user, use_cache=False, dependency_scope="lifespan"), "Depends(get_user, use_cache=False, dependency_scope=\"lifespan\")"], -]) +@pytest.mark.parametrize( + ["depends", "expected_repr"], + [ + [Depends(), "Depends(NoneType)"], + [Depends(get_user), "Depends(get_user)"], + [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], + [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], + [ + Depends(dependency_scope="lifespan"), + 'Depends(NoneType, dependency_scope="lifespan")', + ], + [ + Depends(get_user, dependency_scope="lifespan"), + 'Depends(get_user, dependency_scope="lifespan")', + ], + [ + Depends(use_cache=False, dependency_scope="lifespan"), + 'Depends(NoneType, use_cache=False, dependency_scope="lifespan")', + ], + [ + Depends(get_user, use_cache=False, dependency_scope="lifespan"), + 'Depends(get_user, use_cache=False, dependency_scope="lifespan")', + ], + ], +) def test_depends_repr(depends, expected_repr): assert repr(depends) == expected_repr diff --git a/tests/test_router_events.py b/tests/test_router_events.py index 8289a7301..2f110e684 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -199,9 +199,7 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None: "app_specific": True, "router_specific": True, "overridden": "app", - "__fastapi__": { - "lifespan_scoped_dependencies": {} - }, + "__fastapi__": {"lifespan_scoped_dependencies": {}}, } @@ -219,11 +217,7 @@ def test_merged_no_return_lifespans_return_none() -> None: app.include_router(router) with TestClient(app) as client: - assert client.app_state == { - "__fastapi__": { - "lifespan_scoped_dependencies": {} - } - } + assert client.app_state == {"__fastapi__": {"lifespan_scoped_dependencies": {}}} def test_merged_mixed_state_lifespans() -> None: @@ -248,7 +242,5 @@ def test_merged_mixed_state_lifespans() -> None: with TestClient(app) as client: assert client.app_state == { "router": True, - "__fastapi__": { - "lifespan_scoped_dependencies": {} - } + "__fastapi__": {"lifespan_scoped_dependencies": {}}, } From c4860bfb7cdfff92bfc3e097e8e3d028a98a6adc Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 09:22:14 +0200 Subject: [PATCH 03/29] Added tests for dependency overrides and websockets. Fixed bugs related to the deprecated startup and shutdown events. Fixed bugs related to dependency duplcatation within the same router scope. Made more specific dependency related exceptions. Fixed some linting and mypy related issues. --- fastapi/applications.py | 11 +- fastapi/dependencies/models.py | 26 +- fastapi/dependencies/utils.py | 164 ++-- fastapi/exceptions.py | 16 + fastapi/lifespan.py | 4 +- fastapi/routing.py | 14 +- tests/test_lifespan_scoped_dependencies.py | 703 -------------- .../__init__.py | 0 .../test_dependency_overrides.py | 634 +++++++++++++ .../test_endpoint_usage.py | 854 ++++++++++++++++++ .../testing_utilities.py | 202 +++++ 11 files changed, 1834 insertions(+), 794 deletions(-) delete mode 100644 tests/test_lifespan_scoped_dependencies.py create mode 100644 tests/test_lifespan_scoped_dependencies/__init__.py create mode 100644 tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py create mode 100644 tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py create mode 100644 tests/test_lifespan_scoped_dependencies/testing_utilities.py diff --git a/fastapi/applications.py b/fastapi/applications.py index 625690a74..7da120a34 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -946,8 +946,13 @@ class FastAPI(Starlette): # Since we always use a lifespan, starlette will no longer run event # handlers which are defined in the scope of the application. # We therefore need to call them ourselves. - self._on_startup = on_startup or [] - self._on_shutdown = on_shutdown or [] + if on_startup is None: + on_startup = [] + + if on_shutdown is None: + on_shutdown = [] + self._on_startup = list(on_startup) + self._on_shutdown = list(on_shutdown) self.router: routing.APIRouter = routing.APIRouter( routes=routes, @@ -982,7 +987,7 @@ class FastAPI(Starlette): self.setup() @asynccontextmanager - async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]: + async def _internal_lifespan(self) -> AsyncGenerator[Dict[str, Any], None]: async with AsyncExitStack() as exit_stack: lifespan_scoped_dependencies = await resolve_lifespan_dependants( app=self, diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 471f9c402..b68f0339c 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -12,22 +12,28 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None -LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]] +LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]] @dataclass class LifespanDependant: + call: Callable[..., Any] caller: Callable[..., Any] dependencies: List["LifespanDependant"] = field(default_factory=list) name: Optional[str] = None - call: Optional[Callable[..., Any]] = None use_cache: bool = True + index: Optional[int] = None cache_key: LifespanDependantCacheKey = field(init=False) def __post_init__(self) -> None: if self.use_cache: self.cache_key = self.call - else: + elif self.name is not None: self.cache_key = (self.caller, self.name) + else: + assert self.index is not None, ( + "Lifespan dependency must have an associated name or index." + ) + self.cache_key = (self.caller, self.index) EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] @@ -39,6 +45,7 @@ class EndpointDependant: name: Optional[str] = None call: Optional[Callable[..., Any]] = None use_cache: bool = True + index: Optional[int] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( init=False) path_params: List[ModelField] = field(default_factory=list) @@ -62,7 +69,16 @@ class EndpointDependant: # Kept for backwards compatibility @property def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: - return tuple(self.endpoint_dependencies + self.lifespan_dependencies) + lifespan_dependencies = cast( + List[Union[EndpointDependant, LifespanDependant]], + self.lifespan_dependencies + ) + endpoint_dependencies = cast( + List[Union[EndpointDependant, LifespanDependant]], + self.endpoint_dependencies + ) + + return tuple(lifespan_dependencies + endpoint_dependencies) # Kept for backwards compatibility Dependant = EndpointDependant diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0e9cfe244..46a76e8c4 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -54,11 +54,16 @@ from fastapi.concurrency import ( from fastapi.dependencies.models import ( CacheKey, EndpointDependant, + EndpointDependantCacheKey, LifespanDependant, LifespanDependantCacheKey, SecurityRequirement, ) -from fastapi.exceptions import FastAPIError +from fastapi.exceptions import ( + DependencyScopeConflict, + InvalidDependencyScope, + UninitializedLifespanDependency, +) from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -78,7 +83,7 @@ from starlette.datastructures import ( from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, assert_never, get_args, get_origin multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' @@ -137,7 +142,8 @@ def get_parameterless_sub_dependant( *, depends: params.Depends, path: str, - caller: Callable[..., Any] + caller: Callable[..., Any], + index: int ) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency @@ -146,7 +152,8 @@ def get_parameterless_sub_dependant( depends=depends, dependency=depends.dependency, path=path, - caller=caller + caller=caller, + index=index ) @@ -158,13 +165,15 @@ def get_sub_dependant( caller: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, + index: Optional[int] = None, ) -> Union[EndpointDependant, LifespanDependant]: if depends.dependency_scope == "lifespan": return get_lifespan_dependant( caller=caller, - call=depends.dependency, + call=dependency, name=name, - use_cache=depends.use_cache + use_cache=depends.use_cache, + index=index ) elif depends.dependency_scope == "endpoint": security_requirement = None @@ -185,14 +194,15 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + index=index ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) return sub_dependant else: - raise ValueError( - f"Dependency {name} of {caller} has an invalid " - f"sub-dependency scope: {depends.dependency_scope}" + raise InvalidDependencyScope( + f"Dependency \"{name}\" of {caller} has an invalid " + f"scope: \"{depends.dependency_scope}\"" ) @@ -292,6 +302,7 @@ def get_lifespan_dependant( call: Callable[..., Any], name: Optional[str] = None, use_cache: bool = True, + index: Optional[int] = None ) -> LifespanDependant: dependency_signature = get_typed_signature(call) signature_params = dependency_signature.parameters @@ -299,7 +310,8 @@ def get_lifespan_dependant( call=call, name=name, use_cache=use_cache, - caller=caller + caller=caller, + index=index ) for param_name, param in signature_params.items(): param_details = analyze_param( @@ -309,17 +321,23 @@ def get_lifespan_dependant( is_path_param=False, ) if param_details.depends is None: - raise FastAPIError( - f"Lifespan dependency {dependant.name} was defined with an " - f"invalid argument {param_name}. Lifespan dependencies may " - f"only use other lifespan dependencies as arguments.") + raise DependencyScopeConflict( + f"Lifespan scoped dependency \"{dependant.name}\" was defined " + f"with an invalid argument: \"{param_name}\" which is " + f"\"endpoint\" scoped. Lifespan scoped dependencies may only " + f"use lifespan scoped sub-dependencies.") if param_details.depends.dependency_scope != "lifespan": - raise FastAPIError( - "Lifespan dependency may not use " - "sub-dependencies of other scopes." + raise DependencyScopeConflict( + f"Lifespan scoped dependency {dependant.name} was defined with the " + f"sub-dependency \"{param_name}\" which is " + f"\"{param_details.depends.dependency_scope}\" scoped. " + f"Lifespan scoped dependencies may only use lifespan scoped " + f"sub-dependencies." ) + assert param_details.depends.dependency is not None + sub_dependant = get_lifespan_dependant( name=param_name, call=param_details.depends.dependency, @@ -339,6 +357,7 @@ def get_endpoint_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + index: Optional[int] = None ) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -349,6 +368,7 @@ def get_endpoint_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + index=index ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -359,28 +379,19 @@ def get_endpoint_dependant( is_path_param=is_path_param, ) if param_details.depends is not None: - if param_details.depends.dependency_scope == "endpoint": - sub_dependant = get_param_sub_dependant( - param_name=param_name, - depends=param_details.depends, - path=path, - security_scopes=security_scopes, - caller=call, - ) + sub_dependant = get_param_sub_dependant( + param_name=param_name, + depends=param_details.depends, + path=path, + security_scopes=security_scopes, + caller=call, + ) + if isinstance(sub_dependant, EndpointDependant): dependant.endpoint_dependencies.append(sub_dependant) - elif param_details.depends.dependency_scope == "lifespan": - sub_dependant = get_lifespan_dependant( - caller=call, - call=param_details.depends.dependency, - name=param_name, - use_cache=param_details.depends.use_cache, - ) + elif isinstance(sub_dependant, LifespanDependant): dependant.lifespan_dependencies.append(sub_dependant) else: - raise FastAPIError( - f"Dependency \"{param_name}\" of `{call}` has an invalid " - f"sub-dependency scope: \"{param_details.depends.dependency_scope}\"" - ) + assert_never(sub_dependant) continue if add_non_field_param_to_dependency( param_name=param_name, @@ -652,7 +663,7 @@ async def solve_generator( @dataclass class SolvedLifespanDependant: value: Any - dependency_cache: Dict[Callable[..., Any], Any] + dependency_cache: Dict[LifespanDependantCacheKey, Any] async def solve_lifespan_dependant( @@ -669,35 +680,33 @@ async def solve_lifespan_dependant( dependency_cache=dependency_cache, ) - dependency_arguments: Dict[str, Any] = {} - sub_dependant: LifespanDependant - for sub_dependant in dependant.dependencies: - sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Callable[..., Any], sub_dependant.cache_key + call = dependant.call + dependant_to_solve = dependant + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): + call = getattr( + dependency_overrides_provider, + "dependency_overrides", + {} + ).get(dependant.call, dependant.call) + dependant_to_solve = get_lifespan_dependant( + caller=dependant.caller, + call=call, + name=dependant.name, + use_cache=dependant.use_cache, + index=dependant.index, ) + + dependency_arguments: Dict[str, Any] = {} + for sub_dependant in dependant_to_solve.dependencies: assert sub_dependant.name, ( "Lifespan scoped dependencies should not be able to have " "subdependencies with no name" ) - - sub_dependant_to_solve = sub_dependant - if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides - ): - original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) - sub_dependant_to_solve = get_lifespan_dependant( - call=call, - name=sub_dependant.name, - caller=dependant.call - ) - solved_sub_dependant = await solve_lifespan_dependant( - dependant=sub_dependant_to_solve, + dependant=sub_dependant, dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, async_exit_stack=async_exit_stack, @@ -705,16 +714,16 @@ async def solve_lifespan_dependant( dependency_cache.update(solved_sub_dependant.dependency_cache) dependency_arguments[sub_dependant.name] = solved_sub_dependant.value - if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call): + if is_gen_callable(call) or is_async_gen_callable(call): value = await solve_generator( - call=dependant.call, + call=call, stack=async_exit_stack, sub_values=dependency_arguments ) - elif is_coroutine_callable(dependant.call): - value = await dependant.call(**dependency_arguments) + elif is_coroutine_callable(call): + value = await call(**dependency_arguments) else: - value = await run_in_threadpool(dependant.call, **dependency_arguments) + value = await run_in_threadpool(call, **dependency_arguments) if dependant.cache_key not in dependency_cache: dependency_cache[dependant.cache_key] = value @@ -731,7 +740,7 @@ class SolvedDependency: errors: List[Any] background_tasks: Optional[StarletteBackgroundTasks] response: Response - dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] + dependency_cache: Dict[EndpointDependantCacheKey, Any] async def solve_dependencies( @@ -742,33 +751,34 @@ async def solve_dependencies( background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, + dependency_cache: Optional[Dict[EndpointDependantCacheKey, Any]] = None, async_exit_stack: AsyncExitStack, embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] - for sub_dependant in dependant.lifespan_dependencies: - if sub_dependant.name is None: + for lifespan_sub_dependant in dependant.lifespan_dependencies: + if lifespan_sub_dependant.name is None: continue + try: lifespan_scoped_dependencies = request.state.__fastapi__[ "lifespan_scoped_dependencies"] - except AttributeError as e: - raise FastAPIError( - "FastAPI's internal lifespan was not initialized" + except (AttributeError, KeyError) as e: + raise UninitializedLifespanDependency( + "FastAPI's internal lifespan was not initialized correctly." ) from e try: - value = lifespan_scoped_dependencies[sub_dependant.cache_key] + value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key] except KeyError as e: - raise FastAPIError( - f"Dependency {sub_dependant.name} of {dependant.call} " - f"was not initialized." + raise UninitializedLifespanDependency( + f"Dependency \"{lifespan_sub_dependant.name}\" of " + f"`{dependant.call}` was not initialized correctly." ) from e - values[sub_dependant.name] = value + values[lifespan_sub_dependant.name] = value if response is None: response = Response() diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86..95fd60477 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -146,6 +146,22 @@ class FastAPIError(RuntimeError): """ +class DependencyError(FastAPIError): + pass + + +class InvalidDependencyScope(DependencyError): + pass + + +class DependencyScopeConflict(DependencyError): + pass + + +class UninitializedLifespanDependency(DependencyError): + pass + + class ValidationException(Exception): def __init__(self, errors: Sequence[Any]) -> None: self._errors = errors diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py index 184c943d7..d725a5e30 100644 --- a/fastapi/lifespan.py +++ b/fastapi/lifespan.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey from fastapi.dependencies.utils import solve_lifespan_dependant -from fastapi.routing import APIRoute +from fastapi.routing import APIRoute, APIWebSocketRoute if TYPE_CHECKING: from fastapi import FastAPI @@ -14,7 +14,7 @@ if TYPE_CHECKING: def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} for route in app.router.routes: - if not isinstance(route, APIRoute): + if not isinstance(route, (APIWebSocketRoute, APIRoute)): continue for sub_dependant in route.lifespan_dependencies: diff --git a/fastapi/routing.py b/fastapi/routing.py index 50b51a352..e11edaa13 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -401,15 +401,18 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: + for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, - caller=self + caller=self.__call__, + index=i ) if depends.dependency_scope == "endpoint": + assert isinstance(sub_dependant, EndpointDependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant) elif depends.dependency_scope == "lifespan": + assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) else: assert_never(depends.dependency_scope) @@ -564,15 +567,18 @@ class APIRoute(routing.Route): assert callable(endpoint), "An endpoint must be a callable" self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: + for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, - caller=self.__call__ + caller=self.__call__, + index=i ) if depends.dependency_scope == "endpoint": + assert isinstance(sub_dependant, EndpointDependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant) elif depends.dependency_scope == "lifespan": + assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) else: assert_never(depends.dependency_scope) diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py deleted file mode 100644 index 40d82f0e0..000000000 --- a/tests/test_lifespan_scoped_dependencies.py +++ /dev/null @@ -1,703 +0,0 @@ -from enum import StrEnum, auto -from typing import Any, AsyncGenerator, List, Tuple, TypeVar - -import pytest -from fastapi import ( - APIRouter, - BackgroundTasks, - Body, - Cookie, - Depends, - FastAPI, - File, - Form, - Header, - Path, - Query, -) -from fastapi.exceptions import FastAPIError -from fastapi.params import Security -from fastapi.security import SecurityScopes -from starlette.testclient import TestClient -from typing_extensions import Annotated, Generator, Literal, assert_never - -T = TypeVar('T') - - -class DependencyStyle(StrEnum): - SYNC_FUNCTION = auto() - ASYNC_FUNCTION = auto() - SYNC_GENERATOR = auto() - ASYNC_GENERATOR = auto() - - -class DependencyFactory: - def __init__( - self, - dependency_style: DependencyStyle, *, - should_error: bool = False - ): - self.activation_times = 0 - self.deactivation_times = 0 - self.dependency_style = dependency_style - self._should_error = should_error - - def get_dependency(self): - if self.dependency_style == DependencyStyle.SYNC_FUNCTION: - return self._synchronous_function_dependency - - if self.dependency_style == DependencyStyle.SYNC_GENERATOR: - return self._synchronous_generator_dependency - - if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: - return self._asynchronous_function_dependency - - if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: - return self._asynchronous_generator_dependency - - assert_never(self.dependency_style) - - async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - yield self.activation_times - self.deactivation_times += 1 - - def _synchronous_generator_dependency(self) -> Generator[T, None, None]: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - yield self.activation_times - self.deactivation_times += 1 - - async def _asynchronous_function_dependency(self) -> T: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - return self.activation_times - - def _synchronous_function_dependency(self) -> T: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - return self.activation_times - - -def _expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int -) -> None: - assert dependency_factory.activation_times == 0 - assert dependency_factory.deactivation_times == 0 - with TestClient(app) as client: - assert dependency_factory.activation_times == expected_activation_times - assert dependency_factory.deactivation_times == 0 - - for url, expected_response in urls_and_responses: - response = client.post(url) - response.raise_for_status() - assert response.json() == expected_response - - assert dependency_factory.activation_times == expected_activation_times - assert dependency_factory.deactivation_times == 0 - - assert dependency_factory.activation_times == expected_activation_times - if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION - ): - assert dependency_factory.deactivation_times == expected_activation_times - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - @router.post("/test") - async def endpoint( - dependency: Annotated[None, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - )] - ) -> None: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 - ) - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - if routing_style == "app": - app = FastAPI(dependencies=[depends]) - - @app.post("/test") - async def endpoint() -> None: - return None - else: - app = FastAPI() - router = APIRouter(dependencies=[depends]) - - @router.post("/test") - async def endpoint() -> None: - return None - - app.include_router(router) - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - urls_and_responses=[("/test", None)] * 2, - expected_activation_times=1 - ) - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) -def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"] -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], - ) -> List[int]: - return [sub_dependency1, sub_dependency2] - - @router.post("/test") - async def endpoint( - dependency: Annotated[List[int], Depends( - dependency, - use_cache=use_cache, - dependency_scope=main_dependency_scope, - )] - ) -> List[int]: - return dependency - - if routing_style == "router": - app.include_router(router) - - if use_cache: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test", [1, 1]), - ("/test", [1, 1]), - ], - dependency_factory=dependency_factory, - expected_activation_times=1 - ) - else: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test", [1, 2]), - ("/test", [1, 2]), - ], - dependency_factory=dependency_factory, - expected_activation_times=2 - ) - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: - return dependency3 - - @router.post("/test1") - async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] - ) -> List[int]: - return [dependency1, dependency2, dependency3] - - if routing_style == "router": - app.include_router(router) - - if use_cache: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 1, 1]), - ("/test1", [1, 1, 1]), - ], - dependency_factory=dependency_factory, - expected_activation_times=1 - ) - else: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 2, 3]), - ("/test1", [1, 2, 3]), - ], - dependency_factory=dependency_factory, - expected_activation_times=3 - ) - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: - return dependency3 - - @router.post("/test1") - async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] - ) -> List[int]: - return [dependency1, dependency2, dependency3] - - @router.post("/test2") - async def endpoint2( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] - ) -> List[int]: - return [dependency1, dependency2, dependency3] - - if routing_style == "router": - app.include_router(router) - - if use_cache: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 1, 1]), - ("/test2", [1, 1, 1]), - ("/test1", [1, 1, 1]), - ("/test2", [1, 1, 1]), - ], - dependency_factory=dependency_factory, - expected_activation_times=1 - ) - else: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 2, 3]), - ("/test2", [4, 5, 3]), - ("/test1", [1, 2, 3]), - ("/test2", [4, 5, 3]), - ], - dependency_factory=dependency_factory, - expected_activation_times=5 - ) - -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=False - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends], - ) -> int: - return dependency - - if routing_style == "router": - app.include_router(router) - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 - ) - - -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) -def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation -): - async def dependency_func(param: annotation) -> None: - yield - - app = FastAPI() - - with pytest.raises(FastAPIError): - @app.post("/test") - async def endpoint( - dependency: Annotated[ - None, Depends(dependency_func, dependency_scope="lifespan")] - ) -> None: - return - - -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( - dependency_style: DependencyStyle -): - dependency_factory = DependencyFactory(dependency_style) - - async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] - ) -> AsyncGenerator[int, None]: - yield param - - app = FastAPI() - - @app.post("/test") - async def endpoint( - dependency: Annotated[int, Depends( - lifespan_scoped_dependency, - dependency_scope="lifespan" - )] - ) -> int: - return dependency - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - expected_activation_times=1, - urls_and_responses=[("/test", 1)] * 2 - ) - - -@pytest.mark.parametrize("depends_class", [Depends, Security]) -@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ - "websocket", "endpoint" -]) -def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - route_type -): - async def sub_dependency() -> None: - pass - - async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: - yield - - app = FastAPI() - route_decorator = route_type(app, "/test") - - with pytest.raises(FastAPIError): - @route_decorator - async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] - ) -> None: - return - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_dependencies_must_provide_correct_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - with pytest.raises(FastAPIError): - @router.post("/test") - async def endpoint( - dependency: Annotated[None, Depends( - dependency_factory.get_dependency(), - dependency_scope="incorrect", - use_cache=use_cache, - )] - ) -> None: - assert dependency == 1 - return dependency - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoints_report_incorrect_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - # We intentionally change the dependency scope here to bypass the - # validation at the function level. - depends.dependency_scope = "asdad" - - with pytest.raises(FastAPIError): - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoints_report_uninitialized_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - with TestClient(app) as client: - dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} - - try: - with pytest.raises(FastAPIError): - client.post("/test") - finally: - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoints_report_uninitialized_internal_lifespan( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - with TestClient(app) as client: - internal_state = client.app_state["__fastapi__"] - del client.app_state["__fastapi__"] - - try: - with pytest.raises(FastAPIError): - client.post("/test") - finally: - client.app_state["__fastapi__"] = internal_state - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): - dependency_factory= DependencyFactory(dependency_style, should_error=True) - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - - else: - router = APIRouter() - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - with pytest.raises(ValueError) as exception_info: - with TestClient(app): - pass - - assert exception_info.value.args == (1,) - - -# TODO: Add tests for dependency_overrides -# TODO: Add a websocket equivalent to all tests diff --git a/tests/test_lifespan_scoped_dependencies/__init__.py b/tests/test_lifespan_scoped_dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py new file mode 100644 index 000000000..430765aae --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -0,0 +1,634 @@ +from typing import Any, AsyncGenerator, List, Tuple + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, +) +from fastapi.exceptions import DependencyScopeConflict +from fastapi.params import Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated, Literal + +from tests.test_lifespan_scoped_dependencies.testing_utilities import ( + DependencyFactory, + DependencyStyle, + IntentionallyBadDependency, + create_endpoint_0_annotations, + create_endpoint_1_annotation, + create_endpoint_3_annotations, + use_endpoint, + use_websocket, +) + + +def expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + override_dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == 0 + assert override_dependency_factory.deactivation_times == 0 + + with TestClient(app) as client: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + assert override_dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + if is_websocket: + response = use_websocket(client, url) + else: + response = use_endpoint(client, url) + + assert response == expected_response + + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + assert override_dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION + ): + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.deactivation_times == expected_activation_times + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + ], + expected_value=11 + ) + if routing_style == "router_endpoint": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", 11)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_duplication", [1, 2]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=app, + path="/test", + is_websocket=is_websocket + ) + else: + app = FastAPI() + router = APIRouter(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket + ) + + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 if use_cache else dependency_duplication, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[List[int], Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + )] + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency() + ] = override_dependency_factory.get_dependency() + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [11, 11]), + ("/test", [11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [11, 12]), + ("/test", [11, 12]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency() + ] = override_dependency_factory.get_dependency() + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 11, 11]), + ("/test1", [11, 11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 12, 13]), + ("/test1", [11, 12, 13]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=3, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + create_endpoint_3_annotations( + router=router, + path="/test2", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 11, 11]), + ("/test2", [11, 11, 11]), + ("/test1", [11, 11, 11]), + ("/test2", [11, 11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 12, 13]), + ("/test2", [14, 15, 13]), + ("/test1", [11, 12, 13]), + ("/test2", [14, 15, 13]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=5, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", 11)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("annotation", [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, +]) +def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation, + is_websocket +): + async def dependency_func() -> None: + yield + + async def override_dependency_func(param: annotation) -> None: + yield + + app = FastAPI() + app.dependency_overrides[dependency_func] = override_dependency_func + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, + Depends(dependency_func, dependency_scope="lifespan") + ] + ) + + with pytest.raises(DependencyScopeConflict): + with TestClient(app): + pass + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( + dependency_style: DependencyStyle, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + async def lifespan_scoped_dependency( + param: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )] + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + int, + Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + ], + ) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 11)] * 2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("depends_class", [Depends, Security]) +def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, + is_websocket +): + async def sub_dependency() -> None: + pass + + async def dependency_func() -> None: + yield + + async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + yield + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + ) + + app.dependency_overrides[dependency_func] = override_dependency_func + + with pytest.raises(DependencyScopeConflict): + with TestClient(app): + pass + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_override_lifespan_scoped_dependencies( + use_cache, + dependency_style: DependencyStyle, + routing_style, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, should_error=True) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends] + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + with pytest.raises(IntentionallyBadDependency) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py new file mode 100644 index 000000000..ccf8d896a --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -0,0 +1,854 @@ +import warnings +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Dict, List, Tuple + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, +) +from fastapi.exceptions import ( + DependencyScopeConflict, + InvalidDependencyScope, + UninitializedLifespanDependency, +) +from fastapi.params import Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated, Literal, assert_never + +from tests.test_lifespan_scoped_dependencies.testing_utilities import ( + DependencyFactory, + DependencyStyle, + IntentionallyBadDependency, + create_endpoint_0_annotations, + create_endpoint_1_annotation, + create_endpoint_2_annotations, + create_endpoint_3_annotations, + use_endpoint, + use_websocket, +) + + +def expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + if is_websocket: + assert use_websocket(client, url) == expected_response + else: + assert use_endpoint(client, url) == expected_response + + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION + ): + assert dependency_factory.deactivation_times == expected_activation_times + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket: bool, +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + )], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_duplication", [1, 2]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket: bool, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=app, + path="/test", + is_websocket=is_websocket + ) + else: + app = FastAPI() + router = APIRouter(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket + ) + + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 if use_cache else dependency_duplication, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket: bool, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[List[int], Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + )] + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1]), + ("/test", [1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2]), + ("/test", [1, 2]), + ], + dependency_factory=dependency_factory, + expected_activation_times=2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)] + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1, 1]), + ("/test", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2, 3]), + ("/test", [1, 2, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=3, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)] + ) + + create_endpoint_3_annotations( + router=router, + path="/test2", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)] + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=5, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, + is_websocket, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router": + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("annotation", [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, +]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation, + is_websocket +): + async def dependency_func(param: annotation) -> None: + yield + + app = FastAPI() + + with pytest.raises(DependencyScopeConflict): + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends(dependency_func, dependency_scope="lifespan") + ], + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( + dependency_style: DependencyStyle, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + async def lifespan_scoped_dependency( + param: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )] + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, Depends(lifespan_scoped_dependency)], + expected_value=1 + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize([ + "dependency_style", + "supports_teardown" +], [ + (DependencyStyle.SYNC_FUNCTION, False), + (DependencyStyle.ASYNC_FUNCTION, False), + (DependencyStyle.SYNC_GENERATOR, True), + (DependencyStyle.ASYNC_GENERATOR, True), +]) +def test_the_same_dependency_can_work_in_different_scopes( + dependency_style: DependencyStyle, + supports_teardown, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + app = FastAPI() + + create_endpoint_2_annotations( + router=app, + path="/test", + is_websocket=is_websocket, + annotation1=Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="endpoint" + )], + annotation2=Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )], + ) + if is_websocket: + get_response = use_websocket + else: + get_response = use_endpoint + + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == 1 + assert dependency_factory.deactivation_times == 0 + + assert get_response(client, "/test") == [2, 1] + assert dependency_factory.activation_times == 2 + if supports_teardown: + assert dependency_factory.deactivation_times == 1 + else: + assert dependency_factory.deactivation_times == 0 + + assert get_response(client, "/test") == [3, 1] + assert dependency_factory.activation_times == 3 + if supports_teardown: + assert dependency_factory.deactivation_times == 2 + else: + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == 3 + if supports_teardown: + assert dependency_factory.deactivation_times == 3 + else: + assert dependency_factory.deactivation_times == 0 + + +@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( + dependency_style: DependencyStyle, + is_websocket, + lifespan_style: Literal["lifespan_function", "lifespan_events"] +): + lifespan_started = False + lifespan_ended = False + if lifespan_style == "lifespan_generator": + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: + nonlocal lifespan_started + nonlocal lifespan_ended + lifespan_started = True + yield + lifespan_ended = True + + app = FastAPI(lifespan=lifespan) + elif lifespan_style == "events_decorator": + app = FastAPI() + with warnings.catch_warnings(action="ignore", category=DeprecationWarning): + @app.on_event("startup") + async def startup() -> None: + nonlocal lifespan_started + lifespan_started = True + + @app.on_event("shutdown") + async def shutdown() -> None: + nonlocal lifespan_ended + lifespan_ended = True + elif lifespan_style == "events_constructor": + async def startup() -> None: + nonlocal lifespan_started + lifespan_started = True + + async def shutdown() -> None: + nonlocal lifespan_ended + lifespan_ended = True + app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) + else: + assert_never(lifespan_style) + + dependency_factory = DependencyFactory(dependency_style) + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )], + expected_value=1 + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2, + is_websocket=is_websocket + ) + assert lifespan_started and lifespan_ended + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("depends_class", [Depends, Security]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, + is_websocket +): + async def sub_dependency() -> None: + pass + + async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + yield + + app = FastAPI() + + with pytest.raises(DependencyScopeConflict): + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_dependencies_must_provide_correct_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + with pytest.raises( + InvalidDependencyScope, + match=r'Dependency "value" of .* has an invalid scope: ' + r'"incorrect"' + ): + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + )] + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_incorrect_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + with pytest.raises(InvalidDependencyScope): + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends] + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} + + try: + with pytest.raises(UninitializedLifespanDependency): + if is_websocket: + with client.websocket_connect("/test"): + pass + else: + client.post("/test") + finally: + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_internal_lifespan( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + internal_state = client.app_state["__fastapi__"] + del client.app_state["__fastapi__"] + + try: + with pytest.raises(UninitializedLifespanDependency): + if is_websocket: + with client.websocket_connect("/test"): + pass + else: + client.post("/test") + finally: + client.app_state["__fastapi__"] = internal_state + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_lifespan_scoped_dependencies( + use_cache, + dependency_style: DependencyStyle, + routing_style, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style, should_error=True) + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with pytest.raises(IntentionallyBadDependency) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py new file mode 100644 index 000000000..e733205f5 --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -0,0 +1,202 @@ +from enum import StrEnum, auto +from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never + +from fastapi import APIRouter, FastAPI, WebSocket +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect + +T = TypeVar('T') + + +class DependencyStyle(StrEnum): + SYNC_FUNCTION = auto() + ASYNC_FUNCTION = auto() + SYNC_GENERATOR = auto() + ASYNC_GENERATOR = auto() + + +class IntentionallyBadDependency(Exception): + pass + + +class DependencyFactory: + def __init__( + self, + dependency_style: DependencyStyle, *, + should_error: bool = False, + value_offset: int = 0, + ): + self.activation_times = 0 + self.deactivation_times = 0 + self.dependency_style = dependency_style + self._should_error = should_error + self._value_offset = value_offset + + def get_dependency(self): + if self.dependency_style == DependencyStyle.SYNC_FUNCTION: + return self._synchronous_function_dependency + + if self.dependency_style == DependencyStyle.SYNC_GENERATOR: + return self._synchronous_generator_dependency + + if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: + return self._asynchronous_function_dependency + + if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: + return self._asynchronous_generator_dependency + + assert_never(self.dependency_style) + + async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + yield self.activation_times + self._value_offset + self.deactivation_times += 1 + + def _synchronous_generator_dependency(self) -> Generator[T, None, None]: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + yield self.activation_times + self._value_offset + self.deactivation_times += 1 + + async def _asynchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + return self.activation_times + self._value_offset + + def _synchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + return self.activation_times + self._value_offset + + +def use_endpoint(client: TestClient, url: str) -> Any: + response = client.post(url) + response.raise_for_status() + return response.json() + + +def use_websocket(client: TestClient, url: str) -> Any: + with client.websocket_connect(url) as connection: + return connection.receive_json() + + +def create_endpoint_0_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.send_json(None) + except WebSocketDisconnect: + pass + else: + @router.post(path) + async def endpoint() -> None: + return None + + +def create_endpoint_1_annotation( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation: Any, + expected_value: Any = None +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value: annotation + ) -> None: + if expected_value is not None: + assert value == expected_value + + await websocket.accept() + try: + await websocket.send_json(value) + except WebSocketDisconnect: + pass + else: + @router.post(path) + async def endpoint( + value: annotation + ) -> None: + if expected_value is not None: + assert value == expected_value + + return value + +def create_endpoint_2_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + ) -> None: + await websocket.accept() + try: + await websocket.send_json([value1, value2]) + except WebSocketDisconnect: + await websocket.close() + else: + @router.post(path) + async def endpoint( + value1: annotation1, + value2: annotation2, + ) -> list[Any]: + return [value1, value2] + + +def create_endpoint_3_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, + annotation3: Any +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + value3: annotation3 + ) -> None: + await websocket.accept() + try: + await websocket.send_json([value1, value2, value3]) + except WebSocketDisconnect: + await websocket.close() + else: + @router.post(path) + async def endpoint( + value1: annotation1, + value2: annotation2, + value3: annotation3 + ) -> list[Any]: + return [value1, value2, value3] From d512b03bd3d65b38e82edd138d26bac880e75e4c Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 09:33:58 +0200 Subject: [PATCH 04/29] Applied ruff linting --- fastapi/dependencies/models.py | 18 +- fastapi/dependencies/utils.py | 68 ++-- fastapi/routing.py | 18 +- .../test_dependency_overrides.py | 280 +++++++-------- .../test_endpoint_usage.py | 334 +++++++++--------- .../testing_utilities.py | 97 ++--- 6 files changed, 397 insertions(+), 418 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 2ea813d48..9bbd81d37 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -12,7 +12,10 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None -LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]] +LifespanDependantCacheKey: TypeAlias = Union[ + Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any] +] + @dataclass class LifespanDependant: @@ -30,9 +33,9 @@ class LifespanDependant: elif self.name is not None: self.cache_key = (self.caller, self.name) else: - assert self.index is not None, ( - "Lifespan dependency must have an associated name or index." - ) + assert ( + self.index is not None + ), "Lifespan dependency must have an associated name or index." self.cache_key = (self.caller, self.index) @@ -49,8 +52,7 @@ class EndpointDependant: call: Optional[Callable[..., Any]] = None use_cache: bool = True index: Optional[int] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( - init=False) + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) @@ -74,11 +76,11 @@ class EndpointDependant: def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: lifespan_dependencies = cast( List[Union[EndpointDependant, LifespanDependant]], - self.lifespan_dependencies + self.lifespan_dependencies, ) endpoint_dependencies = cast( List[Union[EndpointDependant, LifespanDependant]], - self.endpoint_dependencies + self.endpoint_dependencies, ) return tuple(lifespan_dependencies + endpoint_dependencies) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96f4a910f..96a00c12a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -147,11 +147,7 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant( - *, - depends: params.Depends, - path: str, - caller: Callable[..., Any], - index: int + *, depends: params.Depends, path: str, caller: Callable[..., Any], index: int ) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency @@ -161,7 +157,7 @@ def get_parameterless_sub_dependant( dependency=depends.dependency, path=path, caller=caller, - index=index + index=index, ) @@ -181,7 +177,7 @@ def get_sub_dependant( call=dependency, name=name, use_cache=depends.use_cache, - index=index + index=index, ) elif depends.dependency_scope == "endpoint": security_requirement = None @@ -202,15 +198,15 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, - index=index + index=index, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) return sub_dependant else: raise InvalidDependencyScope( - f"Dependency \"{name}\" of {caller} has an invalid " - f"scope: \"{depends.dependency_scope}\"" + f'Dependency "{name}" of {caller} has an invalid ' + f'scope: "{depends.dependency_scope}"' ) @@ -233,7 +229,7 @@ def get_flat_dependant( security_requirements=dependant.security_requirements.copy(), lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, - path=dependant.path + path=dependant.path, ) for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: @@ -310,16 +306,12 @@ def get_lifespan_dependant( call: Callable[..., Any], name: Optional[str] = None, use_cache: bool = True, - index: Optional[int] = None + index: Optional[int] = None, ) -> LifespanDependant: dependency_signature = get_typed_signature(call) signature_params = dependency_signature.parameters dependant = LifespanDependant( - call=call, - name=name, - use_cache=use_cache, - caller=caller, - index=index + call=call, name=name, use_cache=use_cache, caller=caller, index=index ) for param_name, param in signature_params.items(): param_details = analyze_param( @@ -330,16 +322,17 @@ def get_lifespan_dependant( ) if param_details.depends is None: raise DependencyScopeConflict( - f"Lifespan scoped dependency \"{dependant.name}\" was defined " - f"with an invalid argument: \"{param_name}\" which is " - f"\"endpoint\" scoped. Lifespan scoped dependencies may only " - f"use lifespan scoped sub-dependencies.") + f'Lifespan scoped dependency "{dependant.name}" was defined ' + f'with an invalid argument: "{param_name}" which is ' + f'"endpoint" scoped. Lifespan scoped dependencies may only ' + f"use lifespan scoped sub-dependencies." + ) if param_details.depends.dependency_scope != "lifespan": raise DependencyScopeConflict( f"Lifespan scoped dependency {dependant.name} was defined with the " - f"sub-dependency \"{param_name}\" which is " - f"\"{param_details.depends.dependency_scope}\" scoped. " + f'sub-dependency "{param_name}" which is ' + f'"{param_details.depends.dependency_scope}" scoped. ' f"Lifespan scoped dependencies may only use lifespan scoped " f"sub-dependencies." ) @@ -350,7 +343,7 @@ def get_lifespan_dependant( name=param_name, call=param_details.depends.dependency, use_cache=param_details.depends.use_cache, - caller=call + caller=call, ) dependant.dependencies.append(sub_dependant) @@ -364,7 +357,7 @@ def get_endpoint_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, - index: Optional[int] = None + index: Optional[int] = None, ) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -375,7 +368,7 @@ def get_endpoint_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, - index=index + index=index, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -692,14 +685,12 @@ async def solve_lifespan_dependant( call = dependant.call dependant_to_solve = dependant if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides ): - call = getattr( - dependency_overrides_provider, - "dependency_overrides", - {} - ).get(dependant.call, dependant.call) + call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get( + dependant.call, dependant.call + ) dependant_to_solve = get_lifespan_dependant( caller=dependant.caller, call=call, @@ -725,9 +716,7 @@ async def solve_lifespan_dependant( if is_gen_callable(call) or is_async_gen_callable(call): value = await solve_generator( - call=call, - stack=async_exit_stack, - sub_values=dependency_arguments + call=call, stack=async_exit_stack, sub_values=dependency_arguments ) elif is_coroutine_callable(call): value = await call(**dependency_arguments) @@ -773,7 +762,8 @@ async def solve_dependencies( try: lifespan_scoped_dependencies = request.state.__fastapi__[ - "lifespan_scoped_dependencies"] + "lifespan_scoped_dependencies" + ] except (AttributeError, KeyError) as e: raise UninitializedLifespanDependency( "FastAPI's internal lifespan was not initialized correctly." @@ -783,8 +773,8 @@ async def solve_dependencies( value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key] except KeyError as e: raise UninitializedLifespanDependency( - f"Dependency \"{lifespan_sub_dependant.name}\" of " - f"`{dependant.call}` was not initialized correctly." + f'Dependency "{lifespan_sub_dependant.name}" of ' + f"`{dependant.call}` was not initialized correctly." ) from e values[lifespan_sub_dependant.name] = value diff --git a/fastapi/routing.py b/fastapi/routing.py index e11edaa13..376ad9c8b 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -400,13 +400,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self.__call__, - index=i + depends=depends, path=self.path_format, caller=self.__call__, index=i ) if depends.dependency_scope == "endpoint": assert isinstance(sub_dependant, EndpointDependant) @@ -566,13 +565,12 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self.__call__, - index=i + depends=depends, path=self.path_format, caller=self.__call__, index=i ) if depends.dependency_scope == "endpoint": assert isinstance(sub_dependant, EndpointDependant) diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py index 430765aae..61d2fe3b2 100644 --- a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -33,13 +33,13 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import ( def expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - override_dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int, - is_websocket: bool + *, + app: FastAPI, + dependency_factory: DependencyFactory, + override_dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool, ) -> None: assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 @@ -62,17 +62,22 @@ def expect_correct_amount_of_dependency_activations( assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 - assert override_dependency_factory.activation_times == expected_activation_times + assert ( + override_dependency_factory.activation_times + == expected_activation_times + ) assert override_dependency_factory.deactivation_times == 0 assert dependency_factory.activation_times == 0 assert override_dependency_factory.activation_times == expected_activation_times if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, ): assert dependency_factory.deactivation_times == 0 - assert override_dependency_factory.deactivation_times == expected_activation_times + assert ( + override_dependency_factory.deactivation_times == expected_activation_times + ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @@ -80,16 +85,10 @@ def expect_correct_amount_of_dependency_activations( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoint_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): dependency_factory = DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) app = FastAPI() @@ -108,14 +107,16 @@ def test_endpoint_dependencies( dependency_factory.get_dependency(), dependency_scope="lifespan", use_cache=use_cache, - ) + ), ], - expected_value=11 + expected_value=11, ) if routing_style == "router_endpoint": app.include_router(router) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -123,7 +124,7 @@ def test_endpoint_dependencies( override_dependency_factory=override_dependency_factory, urls_and_responses=[("/test", 11)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) @@ -133,45 +134,40 @@ def test_endpoint_dependencies( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - dependency_duplication, - is_websocket + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket, ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) if routing_style == "app": app = FastAPI(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=app, - path="/test", - is_websocket=is_websocket + router=app, path="/test", is_websocket=is_websocket ) else: app = FastAPI() router = APIRouter(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=router, - path="/test", - is_websocket=is_websocket + router=router, path="/test", is_websocket=is_websocket ) app.include_router(router) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -179,31 +175,29 @@ def test_router_dependencies( override_dependency_factory=override_dependency_factory, urls_and_responses=[("/test", None)] * 2, expected_activation_times=1 if use_cache else dependency_duplication, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"], - is_websocket + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket, ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -215,8 +209,8 @@ def test_dependency_cache_in_same_dependency( router = APIRouter() async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], ) -> List[int]: return [sub_dependency1, sub_dependency2] @@ -224,19 +218,22 @@ def test_dependency_cache_in_same_dependency( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[List[int], Depends( - dependency, - use_cache=use_cache, - dependency_scope=main_dependency_scope, - )] + annotation=Annotated[ + List[int], + Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + ), + ], ) if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency() - ] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) if use_cache: expect_correct_amount_of_dependency_activations( @@ -248,7 +245,7 @@ def test_dependency_cache_in_same_dependency( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -260,7 +257,7 @@ def test_dependency_cache_in_same_dependency( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @@ -269,21 +266,15 @@ def test_dependency_cache_in_same_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -309,9 +300,9 @@ def test_dependency_cache_in_same_endpoint( if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency() - ] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) if use_cache: expect_correct_amount_of_dependency_activations( @@ -323,7 +314,7 @@ def test_dependency_cache_in_same_endpoint( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -335,29 +326,24 @@ def test_dependency_cache_in_same_endpoint( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=3, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -392,8 +378,9 @@ def test_dependency_cache_in_different_endpoints( if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) if use_cache: expect_correct_amount_of_dependency_activations( @@ -407,7 +394,7 @@ def test_dependency_cache_in_different_endpoints( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -421,27 +408,23 @@ def test_dependency_cache_in_different_endpoints( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=5, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, - is_websocket + dependency_style: DependencyStyle, routing_style, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=False + use_cache=False, ) app = FastAPI() @@ -462,8 +445,9 @@ def test_no_cached_dependency( if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -471,24 +455,27 @@ def test_no_cached_dependency( override_dependency_factory=override_dependency_factory, urls_and_responses=[("/test", 11)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + ], +) def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation, - is_websocket + annotation, is_websocket ): async def dependency_func() -> None: yield @@ -503,9 +490,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, - Depends(dependency_func, dependency_scope="lifespan") - ] + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) with pytest.raises(DependencyScopeConflict): @@ -516,20 +503,16 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( - dependency_style: DependencyStyle, - is_websocket + dependency_style: DependencyStyle, is_websocket ): dependency_factory = DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) -> AsyncGenerator[int, None]: yield param @@ -540,12 +523,13 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco path="/test", is_websocket=is_websocket, annotation=Annotated[ - int, - Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") ], ) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -553,15 +537,14 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco override_dependency_factory=override_dependency_factory, expected_activation_times=1, urls_and_responses=[("/test", 11)] * 2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("depends_class", [Depends, Security]) def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - is_websocket + depends_class, is_websocket ): async def sub_dependency() -> None: pass @@ -569,7 +552,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen async def dependency_func() -> None: yield - async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + async def override_dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: yield app = FastAPI() @@ -578,7 +563,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) app.dependency_overrides[dependency_func] = override_dependency_func @@ -593,12 +580,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_bad_override_lifespan_scoped_dependencies( - use_cache, - dependency_style: DependencyStyle, - routing_style, - is_websocket + use_cache, dependency_style: DependencyStyle, routing_style, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) override_dependency_factory = DependencyFactory(dependency_style, should_error=True) depends = Depends( @@ -619,13 +603,15 @@ def test_bad_override_lifespan_scoped_dependencies( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[int, depends] + annotation=Annotated[int, depends], ) if routing_style == "router_endpoint": app.include_router(router) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) with pytest.raises(IntentionallyBadDependency) as exception_info: with TestClient(app): diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index ccf8d896a..66caf065a 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -40,12 +40,12 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import ( def expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int, - is_websocket: bool + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool, ) -> None: assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 @@ -64,21 +64,23 @@ def expect_correct_amount_of_dependency_activations( assert dependency_factory.activation_times == expected_activation_times if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, ): assert dependency_factory.deactivation_times == expected_activation_times @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) +@pytest.mark.parametrize( + "use_cache", [True, False], ids=["With Cache", "Without Cache"] +) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoint_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket: bool, + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket: bool, ): dependency_factory = DependencyFactory(dependency_style) @@ -93,12 +95,15 @@ def test_endpoint_dependencies( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends( + annotation=Annotated[ + None, + Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", use_cache=use_cache, - )], - expected_value=1 + ), + ], + expected_value=1, ) if routing_style == "router_endpoint": @@ -109,45 +114,42 @@ def test_endpoint_dependencies( dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_duplication", [1, 2]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - dependency_duplication, - is_websocket: bool, + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket: bool, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) if routing_style == "app": app = FastAPI(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=app, - path="/test", - is_websocket=is_websocket + router=app, path="/test", is_websocket=is_websocket ) else: app = FastAPI() router = APIRouter(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=router, - path="/test", - is_websocket=is_websocket + router=router, path="/test", is_websocket=is_websocket ) app.include_router(router) @@ -157,27 +159,28 @@ def test_router_dependencies( dependency_factory=dependency_factory, urls_and_responses=[("/test", None)] * 2, expected_activation_times=1 if use_cache else dependency_duplication, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"], - is_websocket: bool, + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket: bool, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -189,8 +192,8 @@ def test_dependency_cache_in_same_dependency( router = APIRouter() async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], ) -> List[int]: return [sub_dependency1, sub_dependency2] @@ -198,11 +201,14 @@ def test_dependency_cache_in_same_dependency( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[List[int], Depends( - dependency, - use_cache=use_cache, - dependency_scope=main_dependency_scope, - )] + annotation=Annotated[ + List[int], + Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + ), + ], ) if routing_style == "router": @@ -217,7 +223,7 @@ def test_dependency_cache_in_same_dependency( ], dependency_factory=dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -228,7 +234,7 @@ def test_dependency_cache_in_same_dependency( ], dependency_factory=dependency_factory, expected_activation_times=2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @@ -237,17 +243,14 @@ def test_dependency_cache_in_same_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -267,7 +270,7 @@ def test_dependency_cache_in_same_endpoint( is_websocket=is_websocket, annotation1=Annotated[int, depends], annotation2=Annotated[int, depends], - annotation3=Annotated[int, Depends(endpoint_dependency)] + annotation3=Annotated[int, Depends(endpoint_dependency)], ) if routing_style == "router": @@ -282,7 +285,7 @@ def test_dependency_cache_in_same_endpoint( ], dependency_factory=dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -293,25 +296,23 @@ def test_dependency_cache_in_same_endpoint( ], dependency_factory=dependency_factory, expected_activation_times=3, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -331,7 +332,7 @@ def test_dependency_cache_in_different_endpoints( is_websocket=is_websocket, annotation1=Annotated[int, depends], annotation2=Annotated[int, depends], - annotation3=Annotated[int, Depends(endpoint_dependency)] + annotation3=Annotated[int, Depends(endpoint_dependency)], ) create_endpoint_3_annotations( @@ -340,7 +341,7 @@ def test_dependency_cache_in_different_endpoints( is_websocket=is_websocket, annotation1=Annotated[int, depends], annotation2=Annotated[int, depends], - annotation3=Annotated[int, Depends(endpoint_dependency)] + annotation3=Annotated[int, Depends(endpoint_dependency)], ) if routing_style == "router": @@ -357,7 +358,7 @@ def test_dependency_cache_in_different_endpoints( ], dependency_factory=dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -370,23 +371,24 @@ def test_dependency_cache_in_different_endpoints( ], dependency_factory=dependency_factory, expected_activation_times=5, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, - is_websocket, + dependency_style: DependencyStyle, + routing_style, + is_websocket, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=False + use_cache=False, ) app = FastAPI() @@ -402,7 +404,7 @@ def test_no_cached_dependency( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router": @@ -413,25 +415,27 @@ def test_no_cached_dependency( dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + ], +) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation, - is_websocket + annotation, is_websocket ): async def dependency_func(param: annotation) -> None: yield @@ -444,8 +448,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( path="/test", is_websocket=is_websocket, annotation=Annotated[ - None, - Depends(dependency_func, dependency_scope="lifespan") + None, Depends(dependency_func, dependency_scope="lifespan") ], ) @@ -453,16 +456,15 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( - dependency_style: DependencyStyle, - is_websocket + dependency_style: DependencyStyle, is_websocket ): dependency_factory = DependencyFactory(dependency_style) async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) -> AsyncGenerator[int, None]: yield param @@ -473,7 +475,7 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( path="/test", is_websocket=is_websocket, annotation=Annotated[int, Depends(lifespan_scoped_dependency)], - expected_value=1 + expected_value=1, ) expect_correct_amount_of_dependency_activations( @@ -481,24 +483,22 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( dependency_factory=dependency_factory, expected_activation_times=1, urls_and_responses=[("/test", 1)] * 2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize([ - "dependency_style", - "supports_teardown" -], [ - (DependencyStyle.SYNC_FUNCTION, False), - (DependencyStyle.ASYNC_FUNCTION, False), - (DependencyStyle.SYNC_GENERATOR, True), - (DependencyStyle.ASYNC_GENERATOR, True), -]) +@pytest.mark.parametrize( + ["dependency_style", "supports_teardown"], + [ + (DependencyStyle.SYNC_FUNCTION, False), + (DependencyStyle.ASYNC_FUNCTION, False), + (DependencyStyle.SYNC_GENERATOR, True), + (DependencyStyle.ASYNC_GENERATOR, True), + ], +) def test_the_same_dependency_can_work_in_different_scopes( - dependency_style: DependencyStyle, - supports_teardown, - is_websocket + dependency_style: DependencyStyle, supports_teardown, is_websocket ): dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -507,14 +507,14 @@ def test_the_same_dependency_can_work_in_different_scopes( router=app, path="/test", is_websocket=is_websocket, - annotation1=Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="endpoint" - )], - annotation2=Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )], + annotation1=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"), + ], + annotation2=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) if is_websocket: get_response = use_websocket @@ -548,17 +548,20 @@ def test_the_same_dependency_can_work_in_different_scopes( assert dependency_factory.deactivation_times == 0 -@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) +@pytest.mark.parametrize( + "lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"] +) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( - dependency_style: DependencyStyle, - is_websocket, - lifespan_style: Literal["lifespan_function", "lifespan_events"] + dependency_style: DependencyStyle, + is_websocket, + lifespan_style: Literal["lifespan_function", "lifespan_events"], ): lifespan_started = False lifespan_ended = False if lifespan_style == "lifespan_generator": + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: nonlocal lifespan_started @@ -571,6 +574,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( elif lifespan_style == "events_decorator": app = FastAPI() with warnings.catch_warnings(action="ignore", category=DeprecationWarning): + @app.on_event("startup") async def startup() -> None: nonlocal lifespan_started @@ -581,6 +585,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( nonlocal lifespan_ended lifespan_ended = True elif lifespan_style == "events_constructor": + async def startup() -> None: nonlocal lifespan_started lifespan_started = True @@ -588,6 +593,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( async def shutdown() -> None: nonlocal lifespan_ended lifespan_ended = True + app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) else: assert_never(lifespan_style) @@ -598,11 +604,11 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )], - expected_value=1 + annotation=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + expected_value=1, ) expect_correct_amount_of_dependency_activations( @@ -610,20 +616,22 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( dependency_factory=dependency_factory, expected_activation_times=1, urls_and_responses=[("/test", 1)] * 2, - is_websocket=is_websocket + is_websocket=is_websocket, ) assert lifespan_started and lifespan_ended + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("depends_class", [Depends, Security]) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - is_websocket + depends_class, is_websocket ): async def sub_dependency() -> None: pass - async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + async def dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: yield app = FastAPI() @@ -633,20 +641,20 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_dependencies_must_provide_correct_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -656,19 +664,21 @@ def test_dependencies_must_provide_correct_dependency_scope( router = APIRouter() with pytest.raises( - InvalidDependencyScope, - match=r'Dependency "value" of .* has an invalid scope: ' - r'"incorrect"' + InvalidDependencyScope, + match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"', ): create_endpoint_1_annotation( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends( - dependency_factory.get_dependency(), - dependency_scope="incorrect", - use_cache=use_cache, - )] + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + ), + ], ) @@ -677,12 +687,9 @@ def test_dependencies_must_provide_correct_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_incorrect_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -705,7 +712,7 @@ def test_endpoints_report_incorrect_dependency_scope( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[int, depends] + annotation=Annotated[int, depends], ) @@ -714,12 +721,9 @@ def test_endpoints_report_incorrect_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -739,7 +743,7 @@ def test_endpoints_report_uninitialized_dependency( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router_endpoint": @@ -757,7 +761,9 @@ def test_endpoints_report_uninitialized_dependency( else: client.post("/test") finally: - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( + dependencies + ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @@ -765,10 +771,7 @@ def test_endpoints_report_uninitialized_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_internal_lifespan( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): dependency_factory = DependencyFactory(dependency_style) @@ -790,7 +793,7 @@ def test_endpoints_report_uninitialized_internal_lifespan( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router_endpoint": @@ -816,12 +819,9 @@ def test_endpoints_report_uninitialized_internal_lifespan( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_bad_lifespan_scoped_dependencies( - use_cache, - dependency_style: DependencyStyle, - routing_style, - is_websocket + use_cache, dependency_style: DependencyStyle, routing_style, is_websocket ): - dependency_factory= DependencyFactory(dependency_style, should_error=True) + dependency_factory = DependencyFactory(dependency_style, should_error=True) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", @@ -841,7 +841,7 @@ def test_bad_lifespan_scoped_dependencies( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router_endpoint": diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index e733205f5..1f0f100a8 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, FastAPI, WebSocket from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect -T = TypeVar('T') +T = TypeVar("T") class DependencyStyle(StrEnum): @@ -21,10 +21,11 @@ class IntentionallyBadDependency(Exception): class DependencyFactory: def __init__( - self, - dependency_style: DependencyStyle, *, - should_error: bool = False, - value_offset: int = 0, + self, + dependency_style: DependencyStyle, + *, + should_error: bool = False, + value_offset: int = 0, ): self.activation_times = 0 self.deactivation_times = 0 @@ -90,12 +91,13 @@ def use_websocket(client: TestClient, url: str) -> Any: def create_endpoint_0_annotations( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, ) -> None: if is_websocket: + @router.websocket(path) async def endpoint(websocket: WebSocket) -> None: await websocket.accept() @@ -104,25 +106,24 @@ def create_endpoint_0_annotations( except WebSocketDisconnect: pass else: + @router.post(path) async def endpoint() -> None: return None def create_endpoint_1_annotation( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, - annotation: Any, - expected_value: Any = None + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation: Any, + expected_value: Any = None, ) -> None: if is_websocket: + @router.websocket(path) - async def endpoint( - websocket: WebSocket, - value: annotation - ) -> None: + async def endpoint(websocket: WebSocket, value: annotation) -> None: if expected_value is not None: assert value == expected_value @@ -132,29 +133,30 @@ def create_endpoint_1_annotation( except WebSocketDisconnect: pass else: + @router.post(path) - async def endpoint( - value: annotation - ) -> None: + async def endpoint(value: annotation) -> None: if expected_value is not None: assert value == expected_value return value + def create_endpoint_2_annotations( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, - annotation1: Any, - annotation2: Any, + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, ) -> None: if is_websocket: + @router.websocket(path) async def endpoint( - websocket: WebSocket, - value1: annotation1, - value2: annotation2, + websocket: WebSocket, + value1: annotation1, + value2: annotation2, ) -> None: await websocket.accept() try: @@ -162,30 +164,32 @@ def create_endpoint_2_annotations( except WebSocketDisconnect: await websocket.close() else: + @router.post(path) async def endpoint( - value1: annotation1, - value2: annotation2, + value1: annotation1, + value2: annotation2, ) -> list[Any]: return [value1, value2] def create_endpoint_3_annotations( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, - annotation1: Any, - annotation2: Any, - annotation3: Any + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, + annotation3: Any, ) -> None: if is_websocket: + @router.websocket(path) async def endpoint( - websocket: WebSocket, - value1: annotation1, - value2: annotation2, - value3: annotation3 + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + value3: annotation3, ) -> None: await websocket.accept() try: @@ -193,10 +197,9 @@ def create_endpoint_3_annotations( except WebSocketDisconnect: await websocket.close() else: + @router.post(path) async def endpoint( - value1: annotation1, - value2: annotation2, - value3: annotation3 + value1: annotation1, value2: annotation2, value3: annotation3 ) -> list[Any]: return [value1, value2, value3] From e7ab9579238f46b1aa94e635a0c6792fc27ab72a Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 09:41:17 +0200 Subject: [PATCH 05/29] Fixed tests compatibility with older python versions --- tests/test_lifespan_scoped_dependencies.py | 0 .../testing_utilities.py | 12 ++++++------ 2 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 tests/test_lifespan_scoped_dependencies.py diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index 1f0f100a8..a6ad869c8 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,4 +1,4 @@ -from enum import StrEnum, auto +from enum import Enum from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never from fastapi import APIRouter, FastAPI, WebSocket @@ -8,11 +8,11 @@ from starlette.websockets import WebSocketDisconnect T = TypeVar("T") -class DependencyStyle(StrEnum): - SYNC_FUNCTION = auto() - ASYNC_FUNCTION = auto() - SYNC_GENERATOR = auto() - ASYNC_GENERATOR = auto() +class DependencyStyle(str, Enum): + SYNC_FUNCTION = "sync_function" + ASYNC_FUNCTION = "async_function" + SYNC_GENERATOR = "sync_generator" + ASYNC_GENERATOR = "async_generator" class IntentionallyBadDependency(Exception): From 3c5aeaa69a7f466777de338a5ce1d2c2dbe59a23 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 09:44:14 +0200 Subject: [PATCH 06/29] Fixed import of assert_never from typing --- tests/test_lifespan_scoped_dependencies/testing_utilities.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index a6ad869c8..ac1fc0222 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never +from typing import Any, AsyncGenerator, Generator, TypeVar, Union +from typing_extensions import assert_never from fastapi import APIRouter, FastAPI, WebSocket from starlette.testclient import TestClient From 9f2bd41c20ed3da701059808b5f7ccad620c662c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 Nov 2024 07:44:25 +0000 Subject: [PATCH 07/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_lifespan_scoped_dependencies/testing_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index ac1fc0222..e4fd3f5cd 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,10 +1,10 @@ from enum import Enum from typing import Any, AsyncGenerator, Generator, TypeVar, Union -from typing_extensions import assert_never from fastapi import APIRouter, FastAPI, WebSocket from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect +from typing_extensions import assert_never T = TypeVar("T") From 6c923ac4ec25033ba4e9ecd65d43380d2686c4d3 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 09:51:56 +0200 Subject: [PATCH 08/29] Fixed usage of warnings.catch_warnings in a way that is not compatible with older python versions --- tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index 66caf065a..ab52af5e1 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -573,7 +573,8 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( app = FastAPI(lifespan=lifespan) elif lifespan_style == "events_decorator": app = FastAPI() - with warnings.catch_warnings(action="ignore", category=DeprecationWarning): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") @app.on_event("startup") async def startup() -> None: From d38364a2f07f5e944b35d8bb381f6e8cceb239e3 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 10:03:11 +0200 Subject: [PATCH 09/29] Fixed tests compatibility with python 3.8 --- .../testing_utilities.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index e4fd3f5cd..373336562 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, AsyncGenerator, Generator, TypeVar, Union +from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union from fastapi import APIRouter, FastAPI, WebSocket from starlette.testclient import TestClient @@ -136,7 +136,7 @@ def create_endpoint_1_annotation( else: @router.post(path) - async def endpoint(value: annotation) -> None: + async def endpoint(value: annotation) -> Any: if expected_value is not None: assert value == expected_value @@ -170,7 +170,7 @@ def create_endpoint_2_annotations( async def endpoint( value1: annotation1, value2: annotation2, - ) -> list[Any]: + ) -> List[Any]: return [value1, value2] @@ -202,5 +202,5 @@ def create_endpoint_3_annotations( @router.post(path) async def endpoint( value1: annotation1, value2: annotation2, value3: annotation3 - ) -> list[Any]: + ) -> List[Any]: return [value1, value2, value3] From 3b27561d3dc2b40ec48417d9b5e2b85958ca50ca Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 11:41:58 +0200 Subject: [PATCH 10/29] Fixed tests having incorrect ids. Fixed an inconsistent results for tests which check websocket-scoped teardowns before the teardown actually happened on the server side. --- .../test_dependency_overrides.py | 24 +++++----- .../test_endpoint_usage.py | 45 ++++++++++++------- .../testing_utilities.py | 4 ++ 3 files changed, 47 insertions(+), 26 deletions(-) diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py index 61d2fe3b2..9c7800ee2 100644 --- a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -13,6 +13,8 @@ from fastapi import ( Header, Path, Query, + Request, + WebSocket, ) from fastapi.exceptions import DependencyScopeConflict from fastapi.params import Security @@ -80,7 +82,7 @@ def expect_correct_amount_of_dependency_activations( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @@ -128,7 +130,7 @@ def test_endpoint_dependencies( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_duplication", [1, 2]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @@ -179,7 +181,7 @@ def test_router_dependencies( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @@ -261,7 +263,7 @@ def test_dependency_cache_in_same_dependency( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @@ -330,7 +332,7 @@ def test_dependency_cache_in_same_endpoint( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @@ -412,7 +414,7 @@ def test_dependency_cache_in_different_endpoints( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( @@ -459,7 +461,7 @@ def test_no_cached_dependency( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize( "annotation", [ @@ -472,6 +474,8 @@ def test_no_cached_dependency( Annotated[str, Form()], Annotated[str, File()], BackgroundTasks, + Request, + WebSocket, ], ) def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( @@ -500,7 +504,7 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete pass -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( dependency_style: DependencyStyle, is_websocket @@ -541,7 +545,7 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("depends_class", [Depends, Security]) def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( depends_class, is_websocket @@ -575,7 +579,7 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen pass -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index ab52af5e1..54c5ed855 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -1,5 +1,6 @@ import warnings from contextlib import asynccontextmanager +from time import sleep from typing import Any, AsyncGenerator, Dict, List, Tuple import pytest @@ -15,6 +16,8 @@ from fastapi import ( Header, Path, Query, + Request, + WebSocket, ) from fastapi.exceptions import ( DependencyScopeConflict, @@ -70,7 +73,7 @@ def expect_correct_amount_of_dependency_activations( assert dependency_factory.deactivation_times == expected_activation_times -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize( "use_cache", [True, False], ids=["With Cache", "Without Cache"] ) @@ -118,7 +121,7 @@ def test_endpoint_dependencies( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_duplication", [1, 2]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @@ -163,7 +166,7 @@ def test_router_dependencies( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @@ -238,7 +241,7 @@ def test_dependency_cache_in_same_dependency( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @@ -300,7 +303,7 @@ def test_dependency_cache_in_same_endpoint( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @@ -375,7 +378,7 @@ def test_dependency_cache_in_different_endpoints( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( @@ -419,7 +422,7 @@ def test_no_cached_dependency( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize( "annotation", [ @@ -432,6 +435,8 @@ def test_no_cached_dependency( Annotated[str, Form()], Annotated[str, File()], BackgroundTasks, + Request, + WebSocket ], ) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( @@ -453,7 +458,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( dependency_style: DependencyStyle, is_websocket @@ -487,7 +492,7 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize( ["dependency_style", "supports_teardown"], [ @@ -530,6 +535,10 @@ def test_the_same_dependency_can_work_in_different_scopes( assert get_response(client, "/test") == [2, 1] assert dependency_factory.activation_times == 2 if supports_teardown: + if is_websocket: + # Websockets teardown might take some time after the test client + # has disconnected + sleep(0.1) assert dependency_factory.deactivation_times == 1 else: assert dependency_factory.deactivation_times == 0 @@ -537,6 +546,10 @@ def test_the_same_dependency_can_work_in_different_scopes( assert get_response(client, "/test") == [3, 1] assert dependency_factory.activation_times == 3 if supports_teardown: + if is_websocket: + # Websockets teardown might take some time after the test client + # has disconnected + sleep(0.1) assert dependency_factory.deactivation_times == 2 else: assert dependency_factory.deactivation_times == 0 @@ -551,7 +564,7 @@ def test_the_same_dependency_can_work_in_different_scopes( @pytest.mark.parametrize( "lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"] ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( dependency_style: DependencyStyle, @@ -622,7 +635,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( assert lifespan_started and lifespan_ended -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("depends_class", [Depends, Security]) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( depends_class, is_websocket @@ -648,7 +661,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @@ -683,7 +696,7 @@ def test_dependencies_must_provide_correct_dependency_scope( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @@ -717,7 +730,7 @@ def test_endpoints_report_incorrect_dependency_scope( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @@ -767,7 +780,7 @@ def test_endpoints_report_uninitialized_dependency( ) -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) @@ -815,7 +828,7 @@ def test_endpoints_report_uninitialized_internal_lifespan( client.app_state["__fastapi__"] = internal_state -@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index 373336562..77054f941 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,3 +1,4 @@ +import threading from enum import Enum from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union @@ -33,6 +34,7 @@ class DependencyFactory: self.dependency_style = dependency_style self._should_error = should_error self._value_offset = value_offset + self._event = threading.Event() def get_dependency(self): if self.dependency_style == DependencyStyle.SYNC_FUNCTION: @@ -56,6 +58,7 @@ class DependencyFactory: yield self.activation_times + self._value_offset self.deactivation_times += 1 + self._event.set() def _synchronous_generator_dependency(self) -> Generator[T, None, None]: self.activation_times += 1 @@ -64,6 +67,7 @@ class DependencyFactory: yield self.activation_times + self._value_offset self.deactivation_times += 1 + self._event.set() async def _asynchronous_function_dependency(self) -> T: self.activation_times += 1 From 179d6534a8aca962ccaf8109c1e3345ac76c99b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 Nov 2024 09:42:10 +0000 Subject: [PATCH 11/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index 54c5ed855..a6fdea9e4 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -436,7 +436,7 @@ def test_no_cached_dependency( Annotated[str, File()], BackgroundTasks, Request, - WebSocket + WebSocket, ], ) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( From 01e9da1ca212ee84cdde0f0ec3106e268794a2a7 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 12:32:58 +0200 Subject: [PATCH 12/29] Added coverage --- fastapi/dependencies/utils.py | 5 +- fastapi/lifespan.py | 2 +- fastapi/routing.py | 16 ++--- .../test_dependency_overrides.py | 10 +-- .../test_endpoint_usage.py | 71 +++++++++++++++++-- .../testing_utilities.py | 29 ++------ 6 files changed, 85 insertions(+), 48 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96a00c12a..94e983e1e 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -388,10 +388,9 @@ def get_endpoint_dependant( ) if isinstance(sub_dependant, EndpointDependant): dependant.endpoint_dependencies.append(sub_dependant) - elif isinstance(sub_dependant, LifespanDependant): - dependant.lifespan_dependencies.append(sub_dependant) else: - assert_never(sub_dependant) + assert isinstance(sub_dependant, LifespanDependant) + dependant.lifespan_dependencies.append(sub_dependant) continue if add_non_field_param_to_dependency( param_name=param_name, diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py index 9b8afba03..7d53fc00a 100644 --- a/fastapi/lifespan.py +++ b/fastapi/lifespan.py @@ -7,7 +7,7 @@ from fastapi.dependencies.models import LifespanDependant, LifespanDependantCach from fastapi.dependencies.utils import solve_lifespan_dependant from fastapi.routing import APIRoute, APIWebSocketRoute -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: nocover from fastapi import FastAPI diff --git a/fastapi/routing.py b/fastapi/routing.py index 376ad9c8b..7a7344e69 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -73,7 +73,7 @@ from starlette.routing import ( from starlette.routing import Mount as Mount # noqa from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket -from typing_extensions import Annotated, Doc, assert_never, deprecated +from typing_extensions import Annotated, Doc, deprecated def _prepare_response_content( @@ -407,14 +407,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, caller=self.__call__, index=i ) - if depends.dependency_scope == "endpoint": + if isinstance(sub_dependant, EndpointDependant): assert isinstance(sub_dependant, EndpointDependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant) - elif depends.dependency_scope == "lifespan": + else: assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) - else: - assert_never(depends.dependency_scope) self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( @@ -572,14 +570,12 @@ class APIRoute(routing.Route): sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, caller=self.__call__, index=i ) - if depends.dependency_scope == "endpoint": - assert isinstance(sub_dependant, EndpointDependant) + if isinstance(sub_dependant, EndpointDependant): self.dependant.endpoint_dependencies.insert(0, sub_dependant) - elif depends.dependency_scope == "lifespan": + else: assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) - else: - assert_never(depends.dependency_scope) + self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py index 9c7800ee2..eb61ed248 100644 --- a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -482,10 +482,10 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete annotation, is_websocket ): async def dependency_func() -> None: - yield + yield # pragma: nocover async def override_dependency_func(param: annotation) -> None: - yield + yield # pragma: nocover app = FastAPI() app.dependency_overrides[dependency_func] = override_dependency_func @@ -551,15 +551,15 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen depends_class, is_websocket ): async def sub_dependency() -> None: - pass + pass # pragma: nocover async def dependency_func() -> None: - yield + yield # pragma: nocover async def override_dependency_func( param: Annotated[None, depends_class(sub_dependency)], ) -> None: - yield + yield # pragma: nocover app = FastAPI() diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index a6fdea9e4..65c3c2dc1 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -4,6 +4,8 @@ from time import sleep from typing import Any, AsyncGenerator, Dict, List, Tuple import pytest +from setuptools import depends + from fastapi import ( APIRouter, BackgroundTasks, @@ -19,6 +21,7 @@ from fastapi import ( Request, WebSocket, ) +from fastapi.dependencies.utils import get_endpoint_dependant from fastapi.exceptions import ( DependencyScopeConflict, InvalidDependencyScope, @@ -443,7 +446,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( annotation, is_websocket ): async def dependency_func(param: annotation) -> None: - yield + yield # pragma: nocover app = FastAPI() @@ -598,7 +601,8 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( async def shutdown() -> None: nonlocal lifespan_ended lifespan_ended = True - elif lifespan_style == "events_constructor": + else: + assert lifespan_style == "events_constructor" async def startup() -> None: nonlocal lifespan_started @@ -609,8 +613,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( lifespan_ended = True app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) - else: - assert_never(lifespan_style) + dependency_factory = DependencyFactory(dependency_style) @@ -641,12 +644,12 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( depends_class, is_websocket ): async def sub_dependency() -> None: - pass + pass # pragma: nocover async def dependency_func( param: Annotated[None, depends_class(sub_dependency)], ) -> None: - yield + pass # pragma: nocover app = FastAPI() @@ -730,6 +733,39 @@ def test_endpoints_report_incorrect_dependency_scope( ) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_endpoints_report_incorrect_dependency_scope_at_router_scope( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + ) + + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + if routing_style == "app_endpoint": + app = FastAPI(dependencies=[depends]) + router = app + else: + router = APIRouter(dependencies=[depends]) + + + with pytest.raises(InvalidDependencyScope): + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket, + ) + + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @@ -866,3 +902,26 @@ def test_bad_lifespan_scoped_dependencies( pass assert exception_info.value.args == (1,) + +def test_endpoint_dependant_backwards_compatibility(): + dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) + + def endpoint( + dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], + dependency2: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )], + ): + pass # pragma: nocover + + dependant = get_endpoint_dependant( + path="/test", + call=endpoint, + name="endpoint", + ) + + assert dependant.dependencies == tuple( + dependant.lifespan_dependencies + + dependant.endpoint_dependencies + ) \ No newline at end of file diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index 77054f941..88e0925bc 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,10 +1,8 @@ -import threading from enum import Enum from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union from fastapi import APIRouter, FastAPI, WebSocket -from starlette.testclient import TestClient -from starlette.websockets import WebSocketDisconnect +from fastapi.testclient import TestClient from typing_extensions import assert_never T = TypeVar("T") @@ -34,7 +32,6 @@ class DependencyFactory: self.dependency_style = dependency_style self._should_error = should_error self._value_offset = value_offset - self._event = threading.Event() def get_dependency(self): if self.dependency_style == DependencyStyle.SYNC_FUNCTION: @@ -49,7 +46,7 @@ class DependencyFactory: if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: return self._asynchronous_generator_dependency - assert_never(self.dependency_style) + assert_never(self.dependency_style) # pragma: nocover async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: self.activation_times += 1 @@ -58,7 +55,6 @@ class DependencyFactory: yield self.activation_times + self._value_offset self.deactivation_times += 1 - self._event.set() def _synchronous_generator_dependency(self) -> Generator[T, None, None]: self.activation_times += 1 @@ -67,7 +63,6 @@ class DependencyFactory: yield self.activation_times + self._value_offset self.deactivation_times += 1 - self._event.set() async def _asynchronous_function_dependency(self) -> T: self.activation_times += 1 @@ -106,10 +101,7 @@ def create_endpoint_0_annotations( @router.websocket(path) async def endpoint(websocket: WebSocket) -> None: await websocket.accept() - try: - await websocket.send_json(None) - except WebSocketDisconnect: - pass + await websocket.send_json(None) else: @router.post(path) @@ -133,10 +125,7 @@ def create_endpoint_1_annotation( assert value == expected_value await websocket.accept() - try: - await websocket.send_json(value) - except WebSocketDisconnect: - pass + await websocket.send_json(value) else: @router.post(path) @@ -164,10 +153,7 @@ def create_endpoint_2_annotations( value2: annotation2, ) -> None: await websocket.accept() - try: - await websocket.send_json([value1, value2]) - except WebSocketDisconnect: - await websocket.close() + await websocket.send_json([value1, value2]) else: @router.post(path) @@ -197,10 +183,7 @@ def create_endpoint_3_annotations( value3: annotation3, ) -> None: await websocket.accept() - try: - await websocket.send_json([value1, value2, value3]) - except WebSocketDisconnect: - await websocket.close() + await websocket.send_json([value1, value2, value3]) else: @router.post(path) From f9ca77e1ec851871aa32b491b31220766c93fead Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 Nov 2024 10:33:09 +0000 Subject: [PATCH 13/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 2 +- .../test_endpoint_usage.py | 25 +++++++------------ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 94e983e1e..bd1a0dc50 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -83,7 +83,7 @@ from starlette.datastructures import ( from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket -from typing_extensions import Annotated, assert_never, get_args, get_origin +from typing_extensions import Annotated, get_args, get_origin multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index 65c3c2dc1..7e6c80d0e 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -4,8 +4,6 @@ from time import sleep from typing import Any, AsyncGenerator, Dict, List, Tuple import pytest -from setuptools import depends - from fastapi import ( APIRouter, BackgroundTasks, @@ -30,7 +28,7 @@ from fastapi.exceptions import ( from fastapi.params import Security from fastapi.security import SecurityScopes from fastapi.testclient import TestClient -from typing_extensions import Annotated, Literal, assert_never +from typing_extensions import Annotated, Literal from tests.test_lifespan_scoped_dependencies.testing_utilities import ( DependencyFactory, @@ -614,7 +612,6 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) - dependency_factory = DependencyFactory(dependency_style) create_endpoint_1_annotation( @@ -742,10 +739,7 @@ def test_endpoints_report_incorrect_dependency_scope_at_router_scope( ): dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - ) + depends = Depends(dependency_factory.get_dependency(), dependency_scope="lifespan") # We intentionally change the dependency scope here to bypass the # validation at the function level. @@ -757,7 +751,6 @@ def test_endpoints_report_incorrect_dependency_scope_at_router_scope( else: router = APIRouter(dependencies=[depends]) - with pytest.raises(InvalidDependencyScope): create_endpoint_0_annotations( router=router, @@ -903,15 +896,16 @@ def test_bad_lifespan_scoped_dependencies( assert exception_info.value.args == (1,) + def test_endpoint_dependant_backwards_compatibility(): dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) def endpoint( dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], - dependency2: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )], + dependency2: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ): pass # pragma: nocover @@ -922,6 +916,5 @@ def test_endpoint_dependant_backwards_compatibility(): ) assert dependant.dependencies == tuple( - dependant.lifespan_dependencies + - dependant.endpoint_dependencies - ) \ No newline at end of file + dependant.lifespan_dependencies + dependant.endpoint_dependencies + ) From 8a5c3fe56bc4453bd3b86fc0bccf04b5b2dab5a2 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 12:34:18 +0200 Subject: [PATCH 14/29] Fixed more linting --- fastapi/dependencies/utils.py | 2 +- .../test_endpoint_usage.py | 25 +++++++------------ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 94e983e1e..bd1a0dc50 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -83,7 +83,7 @@ from starlette.datastructures import ( from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket -from typing_extensions import Annotated, assert_never, get_args, get_origin +from typing_extensions import Annotated, get_args, get_origin multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index 65c3c2dc1..7e6c80d0e 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -4,8 +4,6 @@ from time import sleep from typing import Any, AsyncGenerator, Dict, List, Tuple import pytest -from setuptools import depends - from fastapi import ( APIRouter, BackgroundTasks, @@ -30,7 +28,7 @@ from fastapi.exceptions import ( from fastapi.params import Security from fastapi.security import SecurityScopes from fastapi.testclient import TestClient -from typing_extensions import Annotated, Literal, assert_never +from typing_extensions import Annotated, Literal from tests.test_lifespan_scoped_dependencies.testing_utilities import ( DependencyFactory, @@ -614,7 +612,6 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) - dependency_factory = DependencyFactory(dependency_style) create_endpoint_1_annotation( @@ -742,10 +739,7 @@ def test_endpoints_report_incorrect_dependency_scope_at_router_scope( ): dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - ) + depends = Depends(dependency_factory.get_dependency(), dependency_scope="lifespan") # We intentionally change the dependency scope here to bypass the # validation at the function level. @@ -757,7 +751,6 @@ def test_endpoints_report_incorrect_dependency_scope_at_router_scope( else: router = APIRouter(dependencies=[depends]) - with pytest.raises(InvalidDependencyScope): create_endpoint_0_annotations( router=router, @@ -903,15 +896,16 @@ def test_bad_lifespan_scoped_dependencies( assert exception_info.value.args == (1,) + def test_endpoint_dependant_backwards_compatibility(): dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) def endpoint( dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], - dependency2: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )], + dependency2: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ): pass # pragma: nocover @@ -922,6 +916,5 @@ def test_endpoint_dependant_backwards_compatibility(): ) assert dependant.dependencies == tuple( - dependant.lifespan_dependencies + - dependant.endpoint_dependencies - ) \ No newline at end of file + dependant.lifespan_dependencies + dependant.endpoint_dependencies + ) From 33f766ab9a715a1541a021c02d18d121876a263e Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 14:32:55 +0200 Subject: [PATCH 15/29] Yet another coverage fixes --- .../test_endpoint_usage.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index 7e6c80d0e..0770c07fb 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -745,7 +745,7 @@ def test_endpoints_report_incorrect_dependency_scope_at_router_scope( # validation at the function level. depends.dependency_scope = "asdad" - if routing_style == "app_endpoint": + if routing_style == "app": app = FastAPI(dependencies=[depends]) router = app else: @@ -800,7 +800,7 @@ def test_endpoints_report_uninitialized_dependency( with pytest.raises(UninitializedLifespanDependency): if is_websocket: with client.websocket_connect("/test"): - pass + pass # pragma: nocover else: client.post("/test") finally: @@ -850,7 +850,7 @@ def test_endpoints_report_uninitialized_internal_lifespan( with pytest.raises(UninitializedLifespanDependency): if is_websocket: with client.websocket_connect("/test"): - pass + pass # pragma: nocover else: client.post("/test") finally: From 01dc24e77cd2ace6e0b8e37634dc00c2fcc1a668 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 21:10:47 +0200 Subject: [PATCH 16/29] Added documentation for lifespan scoped dependencies --- .../lifespan-scoped-dependencies.md | 99 ++++++++++++++ docs_src/dependencies/tutorial013a.py | 39 ++++++ docs_src/dependencies/tutorial013a_an_py39.py | 38 ++++++ docs_src/dependencies/tutorial013b.py | 54 ++++++++ docs_src/dependencies/tutorial013b_an_py39.py | 53 +++++++ docs_src/dependencies/tutorial013c.py | 47 +++++++ docs_src/dependencies/tutorial013c_an_py39.py | 52 +++++++ docs_src/dependencies/tutorial013d.py | 46 +++++++ docs_src/dependencies/tutorial013d_an_py39.py | 45 ++++++ .../test_dependencies/test_tutorial013a.py | 69 ++++++++++ .../test_tutorial013a_an_py39.py | 69 ++++++++++ .../test_dependencies/test_tutorial013b.py | 129 ++++++++++++++++++ .../test_tutorial013b_an_py39.py | 129 ++++++++++++++++++ .../test_dependencies/test_tutorial013c.py | 82 +++++++++++ .../test_tutorial013c_an_py39.py | 82 +++++++++++ .../test_dependencies/test_tutorial013d.py | 80 +++++++++++ .../test_tutorial013d_an_py39.py | 80 +++++++++++ 17 files changed, 1193 insertions(+) create mode 100644 docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md create mode 100644 docs_src/dependencies/tutorial013a.py create mode 100644 docs_src/dependencies/tutorial013a_an_py39.py create mode 100644 docs_src/dependencies/tutorial013b.py create mode 100644 docs_src/dependencies/tutorial013b_an_py39.py create mode 100644 docs_src/dependencies/tutorial013c.py create mode 100644 docs_src/dependencies/tutorial013c_an_py39.py create mode 100644 docs_src/dependencies/tutorial013d.py create mode 100644 docs_src/dependencies/tutorial013d_an_py39.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013a.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013b.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013c.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013d.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md new file mode 100644 index 000000000..d0aea0a16 --- /dev/null +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -0,0 +1,99 @@ +# Lifespan Scoped Dependencies + +So far we've used dependencies which are "endpoint scoped". Meaning, they are +called again and again for every incoming request to the endpoint. However, +this is not ideal for all kinds of dependencies. + +Sometimes dependencies have a large setup/teardown time, or there is a need +for their value to be shared throughout the lifespan of the application. An +example of this would be a connection to a database. Databases are typically +less efficient when working with lots of connections and would prefer that +clients would create a single connection for their operations. + +For such cases, you might want to use "lifespan scoped" dependencies. + +## Intro + +Lifespan scoped dependencies work similarly to the dependencies we've worked +with so far (which are endpoint scoped). However, they are called once and only +once in the application's lifespan (instead of being called again and again for +every request). The returned value will be shared across all requests that need +it. + + +## Create a lifespan scoped dependency + +You may declare a dependency as a lifespan scoped dependency by passing +`dependency_scope="lifespan"` to the `Depends` function: + +{* ../../docs_src/dependencies/tutorial013a_an_py39.py hl[16] *} + +/// tip + +In the example above we saved the annotation to a separate variable, and then +reused it in our endpoints. This is not a requirement, we could also declare +the exact same annotation in both endpoints. However, it is recommended that you +do save the annotation to a variable so you won't accidentally forget to pass +`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint +to create a new database connection for every request). + +/// + +In this example, the `get_database_connection` dependency will be executed once, +during the application's startup. **FastAPI** will internally save the resulting +connection object, and whenever the `read_users` and `read_items` endpoints are +called, they will be using the previously saved connection. Once the application +shuts down, **FastAPI** will make sure to gracefully close the connection object. + +## The `use_cache` argument + +The `use_cache` argument works similarly to the way it worked with endpoint +scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it +will cache dependencies it already encountered before. However, you can disable +this behavior by passing `use_cache=False` to `Depends`: + +{* ../../docs_src/dependencies/tutorial013b_an_py39.py hl[16] *} + +In this example, the `read_users` and `read_groups` endpoints are using +`use_cache=False` whereas the `read_items` and `read_item` are using +`use_cache=True`. That means that we'll have a total of 3 connections created +for the duration of the application's lifespan. One connection will be shared +across all requests for the `read_items` and `read_item` endpoints. A second +connection will be shared across all requests for the `read_users` endpoint. The +third and final connection will be shared across all requests for the +`read_groups` endpoint. + + +## Lifespan Scoped Sub-Dependencies +Just like with endpoint scoped dependencies, lifespan scoped dependencies may +use other lifespan scoped sub-dependencies themselves: + +{* ../../docs_src/dependencies/tutorial013c_an_py39.py hl[16] *} + +Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: + +{* ../../docs_src/dependencies/tutorial013d_an_py39.py hl[16] *} + +/// note + +You can pass `dependency_scope="endpoint"` if you wish to explicitly specify +that a dependency is endpoint scoped. It will work the same as not specifying +a dependency scope at all. + +/// + +As you can see, regardless of the scope, dependencies can use lifespan scoped +sub-dependencies. + +## Dependency Scope Conflicts +By definition, lifespan scoped dependencies are being setup in the application's +startup process, before any request is ever being made to any endpoint. +Therefore, it is not possible for a lifespan scoped dependency to use any +parameters that require the scope of an endpoint. + +That includes but not limited to: + * Parts of the request (like `Body`, `Query` and `Path`) + * The request/response objects themselves (like `Request`, `Response` and `WebSocket`) + * Endpoint scoped sub-dependencies. + +Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. diff --git a/docs_src/dependencies/tutorial013a.py b/docs_src/dependencies/tutorial013a.py new file mode 100644 index 000000000..e014289d2 --- /dev/null +++ b/docs_src/dependencies/tutorial013a.py @@ -0,0 +1,39 @@ +from typing import List + +from fastapi import Depends, FastAPI +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +@app.get("/users/") +async def read_users(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): + return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013a_an_py39.py b/docs_src/dependencies/tutorial013a_an_py39.py new file mode 100644 index 000000000..c2e8c6672 --- /dev/null +++ b/docs_src/dependencies/tutorial013a_an_py39.py @@ -0,0 +1,38 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] + +@app.get("/users/") +async def read_users(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013b.py b/docs_src/dependencies/tutorial013b.py new file mode 100644 index 000000000..0b9907de4 --- /dev/null +++ b/docs_src/dependencies/tutorial013b.py @@ -0,0 +1,54 @@ +from typing import List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") +DedicatedDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan", use_cache=False) + +@app.get("/groups/") +async def read_groups(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): + return await database_connection.get_records("groups") + +@app.get("/users/") +async def read_users(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): + return await database_connection.get_records("items") + +@app.get("/items/{item_id}") +async def read_item( + item_id: str = Path(), + database_connection: MyDatabaseConnection = GlobalDatabaseConnection +): + return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013b_an_py39.py b/docs_src/dependencies/tutorial013b_an_py39.py new file mode 100644 index 000000000..c5274417d --- /dev/null +++ b/docs_src/dependencies/tutorial013b_an_py39.py @@ -0,0 +1,53 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] +DedicatedDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)] + +@app.get("/groups/") +async def read_groups(database_connection: DedicatedDatabaseConnection): + return await database_connection.get_records("groups") + +@app.get("/users/") +async def read_users(database_connection: DedicatedDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("items") + +@app.get("/items/{item_id}") +async def read_item( + database_connection: GlobalDatabaseConnection, + item_id: Annotated[str, Path()] +): + return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013c.py b/docs_src/dependencies/tutorial013c.py new file mode 100644 index 000000000..c4eb99f25 --- /dev/null +++ b/docs_src/dependencies/tutorial013c.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +@dataclass +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + connection_string: str + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_configuration() -> dict: + return { + "database_url": "sqlite:///database.db", + } + +GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") + + +async def get_database_connection( + configuration: dict = GlobalConfiguration +): + async with MyDatabaseConnection(configuration["database_url"]) as connection: + yield connection + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + +@app.get("/users/{user_id}") +async def read_user( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path() +): + return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013c_an_py39.py b/docs_src/dependencies/tutorial013c_an_py39.py new file mode 100644 index 000000000..1830a4b4e --- /dev/null +++ b/docs_src/dependencies/tutorial013c_an_py39.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Annotated, List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +@dataclass +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + connection_string: str + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_configuration() -> dict: + return { + "database_url": "sqlite:///database.db", + } + +GlobalConfiguration = Annotated[dict, Depends(get_configuration, dependency_scope="lifespan")] + + +async def get_database_connection(configuration: GlobalConfiguration): + async with MyDatabaseConnection( + configuration["database_url"]) as connection: + yield connection + +GlobalDatabaseConnection = Annotated[get_database_connection, Depends(get_database_connection, dependency_scope="lifespan")] + + +@app.get("/users/{user_id}") +async def read_user( + database_connection: GlobalDatabaseConnection, + user_id: Annotated[str, Path()] +): + return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013d.py b/docs_src/dependencies/tutorial013d.py new file mode 100644 index 000000000..dc6441620 --- /dev/null +++ b/docs_src/dependencies/tutorial013d.py @@ -0,0 +1,46 @@ +from typing import List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +async def get_user_record( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path() +) -> dict: + return await database_connection.get_record("users", user_id) + + +@app.get("/users/{user_id}") +async def read_user( + user_record: dict = Depends(get_user_record) +): + return user_record diff --git a/docs_src/dependencies/tutorial013d_an_py39.py b/docs_src/dependencies/tutorial013d_an_py39.py new file mode 100644 index 000000000..41aae4c47 --- /dev/null +++ b/docs_src/dependencies/tutorial013d_an_py39.py @@ -0,0 +1,45 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] + + +async def get_user_record( + database_connection: GlobalDatabaseConnection, + user_id: Annotated[str, Path()] +) -> dict: + return await database_connection.get_record("users", user_id) + +@app.get("/users/{user_id}") +async def read_user( + user_record: Annotated[dict, Depends(get_user_record)] +): + return user_record diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a.py b/tests/test_tutorial/test_dependencies/test_tutorial013a.py new file mode 100644 index 000000000..d8a6884c0 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a.py @@ -0,0 +1,69 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + +@pytest.fixture +def database_connection_mock(monkeypatch) -> MockDatabaseConnection: + mock = MockDatabaseConnection() + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + lambda *args, **kwargs: mock + ) + + return mock + + +def test_dependency_usage(database_connection_mock): + assert database_connection_mock.enter_count == 0 + assert database_connection_mock.exit_count == 0 + with TestClient(app) as test_client: + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 1 + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 2 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py new file mode 100644 index 000000000..05af6b665 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -0,0 +1,69 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + +@pytest.fixture +def database_connection_mock(monkeypatch) -> MockDatabaseConnection: + mock = MockDatabaseConnection() + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + lambda *args, **kwargs: mock + ) + + return mock + + +def test_dependency_usage(database_connection_mock): + assert database_connection_mock.enter_count == 0 + assert database_connection_mock.exit_count == 0 + with TestClient(app) as test_client: + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 1 + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 2 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b.py b/tests/test_tutorial/test_dependencies/test_tutorial013b.py new file mode 100644 index 000000000..88942d0f3 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b.py @@ -0,0 +1,129 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 3 + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + assert connection.get_records_count == 0 + assert connection.get_record_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + users_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1: + users_connection = connection + break + + assert users_connection is not None, "No connection was found for users endpoint" + + response = test_client.get('/groups') + assert response.status_code == 200 + assert response.json() == [] + + groups_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1 and connection is not users_connection: + groups_connection = connection + break + + assert groups_connection is not None, "No connection was found for groups endpoint" + assert groups_connection.get_records_count == 1 + + items_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 0: + items_connection = connection + break + + assert items_connection is not None, "No connection was found for items endpoint" + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 0 + + response = test_client.get('/items/asd') + assert response.status_code == 200 + assert response.json() == { + "table_name": "items", + "record_id": "asd", + } + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 1 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py new file mode 100644 index 000000000..6c0c2132f --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -0,0 +1,129 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 3 + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + assert connection.get_records_count == 0 + assert connection.get_record_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + users_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1: + users_connection = connection + break + + assert users_connection is not None, "No connection was found for users endpoint" + + response = test_client.get('/groups') + assert response.status_code == 200 + assert response.json() == [] + + groups_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1 and connection is not users_connection: + groups_connection = connection + break + + assert groups_connection is not None, "No connection was found for groups endpoint" + assert groups_connection.get_records_count == 1 + + items_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 0: + items_connection = connection + break + + assert items_connection is not None, "No connection was found for items endpoint" + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 0 + + response = test_client.get('/items/asd') + assert response.status_code == 200 + assert response.json() == { + "table_name": "items", + "record_id": "asd", + } + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 1 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c.py b/tests/test_tutorial/test_dependencies/test_tutorial013c.py new file mode 100644 index 000000000..09390cd99 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c.py @@ -0,0 +1,82 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self, url: str): + self.url = url + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(cls, url): + mock = MockDatabaseConnection(url) + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.url == "sqlite:///database.db" + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py new file mode 100644 index 000000000..03f01145a --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -0,0 +1,82 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self, url: str): + self.url = url + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(cls, url): + mock = MockDatabaseConnection(url) + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.url == "sqlite:///database.db" + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d.py b/tests/test_tutorial/test_dependencies/test_tutorial013d.py new file mode 100644 index 000000000..dccf6a006 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d.py @@ -0,0 +1,80 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py new file mode 100644 index 000000000..4f27fccca --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -0,0 +1,80 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 From c7d4d10b753d2c49de7d1e05f139ac21c9428bd1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Nov 2024 19:12:38 +0000 Subject: [PATCH 17/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../lifespan-scoped-dependencies.md | 52 +++++++++---------- docs_src/dependencies/tutorial013a.py | 9 +++- docs_src/dependencies/tutorial013a_an_py39.py | 6 ++- docs_src/dependencies/tutorial013b.py | 23 +++++--- docs_src/dependencies/tutorial013b_an_py39.py | 16 ++++-- docs_src/dependencies/tutorial013c.py | 13 +++-- docs_src/dependencies/tutorial013c_an_py39.py | 18 ++++--- docs_src/dependencies/tutorial013d.py | 9 ++-- docs_src/dependencies/tutorial013d_an_py39.py | 13 ++--- .../test_dependencies/test_tutorial013a.py | 10 ++-- .../test_tutorial013a_an_py39.py | 10 ++-- .../test_dependencies/test_tutorial013b.py | 29 ++++++----- .../test_tutorial013b_an_py39.py | 29 ++++++----- .../test_dependencies/test_tutorial013c.py | 10 ++-- .../test_tutorial013c_an_py39.py | 10 ++-- .../test_dependencies/test_tutorial013d.py | 10 ++-- .../test_tutorial013d_an_py39.py | 10 ++-- 17 files changed, 146 insertions(+), 131 deletions(-) diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md index d0aea0a16..2a3585cf9 100644 --- a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -4,68 +4,68 @@ So far we've used dependencies which are "endpoint scoped". Meaning, they are called again and again for every incoming request to the endpoint. However, this is not ideal for all kinds of dependencies. -Sometimes dependencies have a large setup/teardown time, or there is a need -for their value to be shared throughout the lifespan of the application. An -example of this would be a connection to a database. Databases are typically -less efficient when working with lots of connections and would prefer that -clients would create a single connection for their operations. +Sometimes dependencies have a large setup/teardown time, or there is a need +for their value to be shared throughout the lifespan of the application. An +example of this would be a connection to a database. Databases are typically +less efficient when working with lots of connections and would prefer that +clients would create a single connection for their operations. For such cases, you might want to use "lifespan scoped" dependencies. ## Intro -Lifespan scoped dependencies work similarly to the dependencies we've worked +Lifespan scoped dependencies work similarly to the dependencies we've worked with so far (which are endpoint scoped). However, they are called once and only -once in the application's lifespan (instead of being called again and again for -every request). The returned value will be shared across all requests that need +once in the application's lifespan (instead of being called again and again for +every request). The returned value will be shared across all requests that need it. ## Create a lifespan scoped dependency -You may declare a dependency as a lifespan scoped dependency by passing +You may declare a dependency as a lifespan scoped dependency by passing `dependency_scope="lifespan"` to the `Depends` function: {* ../../docs_src/dependencies/tutorial013a_an_py39.py hl[16] *} /// tip -In the example above we saved the annotation to a separate variable, and then -reused it in our endpoints. This is not a requirement, we could also declare -the exact same annotation in both endpoints. However, it is recommended that you -do save the annotation to a variable so you won't accidentally forget to pass +In the example above we saved the annotation to a separate variable, and then +reused it in our endpoints. This is not a requirement, we could also declare +the exact same annotation in both endpoints. However, it is recommended that you +do save the annotation to a variable so you won't accidentally forget to pass `dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint to create a new database connection for every request). /// In this example, the `get_database_connection` dependency will be executed once, -during the application's startup. **FastAPI** will internally save the resulting -connection object, and whenever the `read_users` and `read_items` endpoints are +during the application's startup. **FastAPI** will internally save the resulting +connection object, and whenever the `read_users` and `read_items` endpoints are called, they will be using the previously saved connection. Once the application shuts down, **FastAPI** will make sure to gracefully close the connection object. ## The `use_cache` argument -The `use_cache` argument works similarly to the way it worked with endpoint +The `use_cache` argument works similarly to the way it worked with endpoint scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it will cache dependencies it already encountered before. However, you can disable this behavior by passing `use_cache=False` to `Depends`: {* ../../docs_src/dependencies/tutorial013b_an_py39.py hl[16] *} -In this example, the `read_users` and `read_groups` endpoints are using -`use_cache=False` whereas the `read_items` and `read_item` are using -`use_cache=True`. That means that we'll have a total of 3 connections created -for the duration of the application's lifespan. One connection will be shared -across all requests for the `read_items` and `read_item` endpoints. A second -connection will be shared across all requests for the `read_users` endpoint. The -third and final connection will be shared across all requests for the +In this example, the `read_users` and `read_groups` endpoints are using +`use_cache=False` whereas the `read_items` and `read_item` are using +`use_cache=True`. That means that we'll have a total of 3 connections created +for the duration of the application's lifespan. One connection will be shared +across all requests for the `read_items` and `read_item` endpoints. A second +connection will be shared across all requests for the `read_users` endpoint. The +third and final connection will be shared across all requests for the `read_groups` endpoint. ## Lifespan Scoped Sub-Dependencies -Just like with endpoint scoped dependencies, lifespan scoped dependencies may +Just like with endpoint scoped dependencies, lifespan scoped dependencies may use other lifespan scoped sub-dependencies themselves: {* ../../docs_src/dependencies/tutorial013c_an_py39.py hl[16] *} @@ -87,8 +87,8 @@ sub-dependencies. ## Dependency Scope Conflicts By definition, lifespan scoped dependencies are being setup in the application's -startup process, before any request is ever being made to any endpoint. -Therefore, it is not possible for a lifespan scoped dependency to use any +startup process, before any request is ever being made to any endpoint. +Therefore, it is not possible for a lifespan scoped dependency to use any parameters that require the scope of an endpoint. That includes but not limited to: diff --git a/docs_src/dependencies/tutorial013a.py b/docs_src/dependencies/tutorial013a.py index e014289d2..83687af24 100644 --- a/docs_src/dependencies/tutorial013a.py +++ b/docs_src/dependencies/tutorial013a.py @@ -18,6 +18,7 @@ class MyDatabaseConnection: async def get_records(self, table_name: str) -> List[dict]: pass + app = FastAPI() @@ -30,10 +31,14 @@ GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="li @app.get("/users/") -async def read_users(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): +async def read_users( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): return await database_connection.get_records("users") @app.get("/items/") -async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): +async def read_items( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013a_an_py39.py b/docs_src/dependencies/tutorial013a_an_py39.py index c2e8c6672..62f10a6e1 100644 --- a/docs_src/dependencies/tutorial013a_an_py39.py +++ b/docs_src/dependencies/tutorial013a_an_py39.py @@ -8,6 +8,7 @@ class MyDatabaseConnection: """ This is a mock just for example purposes. """ + async def __aenter__(self) -> Self: return self @@ -26,7 +27,10 @@ async def get_database_connection(): yield connection -GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] +GlobalDatabaseConnection = Annotated[ + MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") +] + @app.get("/users/") async def read_users(database_connection: GlobalDatabaseConnection): diff --git a/docs_src/dependencies/tutorial013b.py b/docs_src/dependencies/tutorial013b.py index 0b9907de4..3123b64f5 100644 --- a/docs_src/dependencies/tutorial013b.py +++ b/docs_src/dependencies/tutorial013b.py @@ -21,6 +21,7 @@ class MyDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: pass + app = FastAPI() @@ -29,26 +30,36 @@ async def get_database_connection(): yield connection - GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") -DedicatedDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan", use_cache=False) +DedicatedDatabaseConnection = Depends( + get_database_connection, dependency_scope="lifespan", use_cache=False +) + @app.get("/groups/") -async def read_groups(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): +async def read_groups( + database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, +): return await database_connection.get_records("groups") + @app.get("/users/") -async def read_users(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): +async def read_users( + database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, +): return await database_connection.get_records("users") @app.get("/items/") -async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): +async def read_items( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, +): return await database_connection.get_records("items") + @app.get("/items/{item_id}") async def read_item( item_id: str = Path(), - database_connection: MyDatabaseConnection = GlobalDatabaseConnection + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, ): return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013b_an_py39.py b/docs_src/dependencies/tutorial013b_an_py39.py index c5274417d..cc7205f40 100644 --- a/docs_src/dependencies/tutorial013b_an_py39.py +++ b/docs_src/dependencies/tutorial013b_an_py39.py @@ -21,6 +21,7 @@ class MyDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: pass + app = FastAPI() @@ -29,13 +30,20 @@ async def get_database_connection(): yield connection -GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] -DedicatedDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)] +GlobalDatabaseConnection = Annotated[ + MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") +] +DedicatedDatabaseConnection = Annotated[ + MyDatabaseConnection, + Depends(get_database_connection, dependency_scope="lifespan", use_cache=False), +] + @app.get("/groups/") async def read_groups(database_connection: DedicatedDatabaseConnection): return await database_connection.get_records("groups") + @app.get("/users/") async def read_users(database_connection: DedicatedDatabaseConnection): return await database_connection.get_records("users") @@ -45,9 +53,9 @@ async def read_users(database_connection: DedicatedDatabaseConnection): async def read_items(database_connection: GlobalDatabaseConnection): return await database_connection.get_records("items") + @app.get("/items/{item_id}") async def read_item( - database_connection: GlobalDatabaseConnection, - item_id: Annotated[str, Path()] + database_connection: GlobalDatabaseConnection, item_id: Annotated[str, Path()] ): return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013c.py b/docs_src/dependencies/tutorial013c.py index c4eb99f25..c8814adc0 100644 --- a/docs_src/dependencies/tutorial013c.py +++ b/docs_src/dependencies/tutorial013c.py @@ -9,6 +9,7 @@ class MyDatabaseConnection: """ This is a mock just for example purposes. """ + connection_string: str async def __aenter__(self) -> Self: @@ -20,6 +21,7 @@ class MyDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: pass + app = FastAPI() @@ -28,20 +30,21 @@ async def get_configuration() -> dict: "database_url": "sqlite:///database.db", } + GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") -async def get_database_connection( - configuration: dict = GlobalConfiguration -): +async def get_database_connection(configuration: dict = GlobalConfiguration): async with MyDatabaseConnection(configuration["database_url"]) as connection: yield connection + GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + @app.get("/users/{user_id}") async def read_user( - database_connection: MyDatabaseConnection = GlobalDatabaseConnection, - user_id: str = Path() + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path(), ): return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013c_an_py39.py b/docs_src/dependencies/tutorial013c_an_py39.py index 1830a4b4e..a111f1b6d 100644 --- a/docs_src/dependencies/tutorial013c_an_py39.py +++ b/docs_src/dependencies/tutorial013c_an_py39.py @@ -10,6 +10,7 @@ class MyDatabaseConnection: """ This is a mock just for example purposes. """ + connection_string: str async def __aenter__(self) -> Self: @@ -33,20 +34,25 @@ async def get_configuration() -> dict: "database_url": "sqlite:///database.db", } -GlobalConfiguration = Annotated[dict, Depends(get_configuration, dependency_scope="lifespan")] + +GlobalConfiguration = Annotated[ + dict, Depends(get_configuration, dependency_scope="lifespan") +] async def get_database_connection(configuration: GlobalConfiguration): - async with MyDatabaseConnection( - configuration["database_url"]) as connection: + async with MyDatabaseConnection(configuration["database_url"]) as connection: yield connection -GlobalDatabaseConnection = Annotated[get_database_connection, Depends(get_database_connection, dependency_scope="lifespan")] + +GlobalDatabaseConnection = Annotated[ + get_database_connection, + Depends(get_database_connection, dependency_scope="lifespan"), +] @app.get("/users/{user_id}") async def read_user( - database_connection: GlobalDatabaseConnection, - user_id: Annotated[str, Path()] + database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] ): return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013d.py b/docs_src/dependencies/tutorial013d.py index dc6441620..dd04b37bc 100644 --- a/docs_src/dependencies/tutorial013d.py +++ b/docs_src/dependencies/tutorial013d.py @@ -21,6 +21,7 @@ class MyDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: pass + app = FastAPI() @@ -33,14 +34,12 @@ GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="li async def get_user_record( - database_connection: MyDatabaseConnection = GlobalDatabaseConnection, - user_id: str = Path() + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path(), ) -> dict: return await database_connection.get_record("users", user_id) @app.get("/users/{user_id}") -async def read_user( - user_record: dict = Depends(get_user_record) -): +async def read_user(user_record: dict = Depends(get_user_record)): return user_record diff --git a/docs_src/dependencies/tutorial013d_an_py39.py b/docs_src/dependencies/tutorial013d_an_py39.py index 41aae4c47..57ef5676a 100644 --- a/docs_src/dependencies/tutorial013d_an_py39.py +++ b/docs_src/dependencies/tutorial013d_an_py39.py @@ -21,6 +21,7 @@ class MyDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: pass + app = FastAPI() @@ -29,17 +30,17 @@ async def get_database_connection(): yield connection -GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] +GlobalDatabaseConnection = Annotated[ + MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") +] async def get_user_record( - database_connection: GlobalDatabaseConnection, - user_id: Annotated[str, Path()] + database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] ) -> dict: return await database_connection.get_record("users", user_id) + @app.get("/users/{user_id}") -async def read_user( - user_record: Annotated[dict, Depends(get_user_record)] -): +async def read_user(user_record: Annotated[dict, Depends(get_user_record)]): return user_record diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a.py b/tests/test_tutorial/test_dependencies/test_tutorial013a.py index d8a6884c0..7b5d823f9 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013a.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a.py @@ -34,11 +34,7 @@ class MockDatabaseConnection: def database_connection_mock(monkeypatch) -> MockDatabaseConnection: mock = MockDatabaseConnection() - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - lambda *args, **kwargs: mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) return mock @@ -50,13 +46,13 @@ def test_dependency_usage(database_connection_mock): assert database_connection_mock.enter_count == 1 assert database_connection_mock.exit_count == 0 - response = test_client.get('/users') + response = test_client.get("/users") assert response.status_code == 200 assert response.json() == [] assert database_connection_mock.get_records_count == 1 - response = test_client.get('/items') + response = test_client.get("/items") assert response.status_code == 200 assert response.json() == [] diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py index 05af6b665..cf8c7326f 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -34,11 +34,7 @@ class MockDatabaseConnection: def database_connection_mock(monkeypatch) -> MockDatabaseConnection: mock = MockDatabaseConnection() - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - lambda *args, **kwargs: mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) return mock @@ -50,13 +46,13 @@ def test_dependency_usage(database_connection_mock): assert database_connection_mock.enter_count == 1 assert database_connection_mock.exit_count == 0 - response = test_client.get('/users') + response = test_client.get("/users") assert response.status_code == 200 assert response.json() == [] assert database_connection_mock.get_records_count == 1 - response = test_client.get('/items') + response = test_client.get("/items") assert response.status_code == 200 assert response.json() == [] diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b.py b/tests/test_tutorial/test_dependencies/test_tutorial013b.py index 88942d0f3..084f0bffc 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b.py @@ -40,22 +40,17 @@ class MockDatabaseConnection: } - @pytest.fixture def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: connections = [] + def _get_new_connection_mock(*args, **kwargs): mock = MockDatabaseConnection() connections.append(mock) return mock - - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - _get_new_connection_mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) return connections @@ -70,7 +65,7 @@ def test_dependency_usage(database_connection_mocks): assert connection.get_records_count == 0 assert connection.get_record_count == 0 - response = test_client.get('/users') + response = test_client.get("/users") assert response.status_code == 200 assert response.json() == [] @@ -80,9 +75,11 @@ def test_dependency_usage(database_connection_mocks): users_connection = connection break - assert users_connection is not None, "No connection was found for users endpoint" + assert ( + users_connection is not None + ), "No connection was found for users endpoint" - response = test_client.get('/groups') + response = test_client.get("/groups") assert response.status_code == 200 assert response.json() == [] @@ -92,7 +89,9 @@ def test_dependency_usage(database_connection_mocks): groups_connection = connection break - assert groups_connection is not None, "No connection was found for groups endpoint" + assert ( + groups_connection is not None + ), "No connection was found for groups endpoint" assert groups_connection.get_records_count == 1 items_connection = None @@ -101,16 +100,18 @@ def test_dependency_usage(database_connection_mocks): items_connection = connection break - assert items_connection is not None, "No connection was found for items endpoint" + assert ( + items_connection is not None + ), "No connection was found for items endpoint" - response = test_client.get('/items') + response = test_client.get("/items") assert response.status_code == 200 assert response.json() == [] assert items_connection.get_records_count == 1 assert items_connection.get_record_count == 0 - response = test_client.get('/items/asd') + response = test_client.get("/items/asd") assert response.status_code == 200 assert response.json() == { "table_name": "items", diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py index 6c0c2132f..46e0c68d0 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -40,22 +40,17 @@ class MockDatabaseConnection: } - @pytest.fixture def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: connections = [] + def _get_new_connection_mock(*args, **kwargs): mock = MockDatabaseConnection() connections.append(mock) return mock - - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - _get_new_connection_mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) return connections @@ -70,7 +65,7 @@ def test_dependency_usage(database_connection_mocks): assert connection.get_records_count == 0 assert connection.get_record_count == 0 - response = test_client.get('/users') + response = test_client.get("/users") assert response.status_code == 200 assert response.json() == [] @@ -80,9 +75,11 @@ def test_dependency_usage(database_connection_mocks): users_connection = connection break - assert users_connection is not None, "No connection was found for users endpoint" + assert ( + users_connection is not None + ), "No connection was found for users endpoint" - response = test_client.get('/groups') + response = test_client.get("/groups") assert response.status_code == 200 assert response.json() == [] @@ -92,7 +89,9 @@ def test_dependency_usage(database_connection_mocks): groups_connection = connection break - assert groups_connection is not None, "No connection was found for groups endpoint" + assert ( + groups_connection is not None + ), "No connection was found for groups endpoint" assert groups_connection.get_records_count == 1 items_connection = None @@ -101,16 +100,18 @@ def test_dependency_usage(database_connection_mocks): items_connection = connection break - assert items_connection is not None, "No connection was found for items endpoint" + assert ( + items_connection is not None + ), "No connection was found for items endpoint" - response = test_client.get('/items') + response = test_client.get("/items") assert response.status_code == 200 assert response.json() == [] assert items_connection.get_records_count == 1 assert items_connection.get_record_count == 0 - response = test_client.get('/items/asd') + response = test_client.get("/items/asd") assert response.status_code == 200 assert response.json() == { "table_name": "items", diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c.py b/tests/test_tutorial/test_dependencies/test_tutorial013c.py index 09390cd99..800eaade7 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013c.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c.py @@ -34,21 +34,17 @@ class MockDatabaseConnection: } - @pytest.fixture def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: connections = [] + def _get_new_connection_mock(cls, url): mock = MockDatabaseConnection(url) connections.append(mock) return mock - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - _get_new_connection_mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) return connections @@ -64,7 +60,7 @@ def test_dependency_usage(database_connection_mocks): assert database_connection_mock.exit_count == 0 assert database_connection_mock.get_record_count == 0 - response = test_client.get('/users/user') + response = test_client.get("/users/user") assert response.status_code == 200 assert response.json() == { "table_name": "users", diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py index 03f01145a..705cb701b 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -34,21 +34,17 @@ class MockDatabaseConnection: } - @pytest.fixture def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: connections = [] + def _get_new_connection_mock(cls, url): mock = MockDatabaseConnection(url) connections.append(mock) return mock - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - _get_new_connection_mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) return connections @@ -64,7 +60,7 @@ def test_dependency_usage(database_connection_mocks): assert database_connection_mock.exit_count == 0 assert database_connection_mock.get_record_count == 0 - response = test_client.get('/users/user') + response = test_client.get("/users/user") assert response.status_code == 200 assert response.json() == { "table_name": "users", diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d.py b/tests/test_tutorial/test_dependencies/test_tutorial013d.py index dccf6a006..eb01a7232 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013d.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d.py @@ -33,21 +33,17 @@ class MockDatabaseConnection: } - @pytest.fixture def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: connections = [] + def _get_new_connection_mock(*args, **kwargs): mock = MockDatabaseConnection() connections.append(mock) return mock - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - _get_new_connection_mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) return connections @@ -62,7 +58,7 @@ def test_dependency_usage(database_connection_mocks): assert database_connection_mock.exit_count == 0 assert database_connection_mock.get_record_count == 0 - response = test_client.get('/users/user') + response = test_client.get("/users/user") assert response.status_code == 200 assert response.json() == { "table_name": "users", diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py index 4f27fccca..6b1562749 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -33,21 +33,17 @@ class MockDatabaseConnection: } - @pytest.fixture def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: connections = [] + def _get_new_connection_mock(*args, **kwargs): mock = MockDatabaseConnection() connections.append(mock) return mock - monkeypatch.setattr( - MyDatabaseConnection, - "__new__", - _get_new_connection_mock - ) + monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) return connections @@ -62,7 +58,7 @@ def test_dependency_usage(database_connection_mocks): assert database_connection_mock.exit_count == 0 assert database_connection_mock.get_record_count == 0 - response = test_client.get('/users/user') + response = test_client.get("/users/user") assert response.status_code == 200 assert response.json() == { "table_name": "users", From 4df502eb9923bf168b8895b0ca89bb7c7913cd29 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 21:21:36 +0200 Subject: [PATCH 18/29] Added missing decorator for python3.9 tests --- .../test_dependencies/test_tutorial013a_an_py39.py | 2 ++ .../test_dependencies/test_tutorial013b_an_py39.py | 2 ++ .../test_dependencies/test_tutorial013c_an_py39.py | 2 ++ .../test_dependencies/test_tutorial013d_an_py39.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py index 05af6b665..cf5417a82 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app +from ...utils import needs_py39 class MockDatabaseConnection: @@ -43,6 +44,7 @@ def database_connection_mock(monkeypatch) -> MockDatabaseConnection: return mock +@needs_py39 def test_dependency_usage(database_connection_mock): assert database_connection_mock.enter_count == 0 assert database_connection_mock.exit_count == 0 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py index 6c0c2132f..9a1204fe8 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app +from ...utils import needs_py39 class MockDatabaseConnection: @@ -59,6 +60,7 @@ def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: return connections +@needs_py39 def test_dependency_usage(database_connection_mocks): assert len(database_connection_mocks) == 0 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py index 03f01145a..10d7404da 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app +from ...utils import needs_py39 class MockDatabaseConnection: @@ -52,6 +53,7 @@ def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: return connections +@needs_py39 def test_dependency_usage(database_connection_mocks): assert len(database_connection_mocks) == 0 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py index 4f27fccca..48df82582 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app +from ...utils import needs_py39 class MockDatabaseConnection: @@ -51,6 +52,7 @@ def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: return connections +@needs_py39 def test_dependency_usage(database_connection_mocks): assert len(database_connection_mocks) == 0 From 17873cf291e74670cf086ca20be460a58e5aca9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Nov 2024 19:22:01 +0000 Subject: [PATCH 19/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_tutorial/test_dependencies/test_tutorial013a_an_py39.py | 1 + .../test_tutorial/test_dependencies/test_tutorial013b_an_py39.py | 1 + .../test_tutorial/test_dependencies/test_tutorial013c_an_py39.py | 1 + .../test_tutorial/test_dependencies/test_tutorial013d_an_py39.py | 1 + 4 files changed, 4 insertions(+) diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py index 4f4751ed9..d755e9c4a 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app + from ...utils import needs_py39 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py index 2015ee2a9..f19593563 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app + from ...utils import needs_py39 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py index ff6799e2d..5f9b420bb 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app + from ...utils import needs_py39 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py index 386422ce2..3f7965894 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from typing_extensions import Self from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app + from ...utils import needs_py39 From aaf309f8e1f7a0b59f10d0e8439bc9071065e39d Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 21:25:51 +0200 Subject: [PATCH 20/29] Added missing reference to lifespan scoped dependencies page in the learn tab --- docs/en/mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/en/mkdocs.yml b/docs/en/mkdocs.yml index 6443b290a..e22ee373c 100644 --- a/docs/en/mkdocs.yml +++ b/docs/en/mkdocs.yml @@ -144,6 +144,7 @@ nav: - tutorial/dependencies/dependencies-in-path-operation-decorators.md - tutorial/dependencies/global-dependencies.md - tutorial/dependencies/dependencies-with-yield.md + - tutorial/dependencies/lifespan-scoped-dependencies.md - Security: - tutorial/security/index.md - tutorial/security/first-steps.md From b7864a611cbe1cbe31c384546eef3ac36062d6c6 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 21:32:50 +0200 Subject: [PATCH 21/29] Updated tutorial tests to only import for valid python versions --- .../test_dependencies/test_tutorial013a_an_py39.py | 4 +++- .../test_dependencies/test_tutorial013b_an_py39.py | 4 +++- .../test_dependencies/test_tutorial013c_an_py39.py | 4 +++- .../test_dependencies/test_tutorial013d_an_py39.py | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py index d755e9c4a..90775b6b0 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -1,10 +1,12 @@ +import sys from typing import List import pytest from starlette.testclient import TestClient from typing_extensions import Self -from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app from ...utils import needs_py39 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py index f19593563..b12fa0637 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -1,10 +1,12 @@ +import sys from typing import List import pytest from starlette.testclient import TestClient from typing_extensions import Self -from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app from ...utils import needs_py39 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py index 5f9b420bb..80ac67f42 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -1,10 +1,12 @@ +import sys from typing import List import pytest from starlette.testclient import TestClient from typing_extensions import Self -from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app from ...utils import needs_py39 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py index 3f7965894..8563325da 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -1,10 +1,12 @@ +import sys from typing import List import pytest from starlette.testclient import TestClient from typing_extensions import Self -from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app +if sys.version_info >= (3, 9): + from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app from ...utils import needs_py39 From 543c2622c407f5ec8dee10aaae1f9e59ddaec688 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 21:42:33 +0200 Subject: [PATCH 22/29] Removed unintended example highlighting --- .../tutorial/dependencies/lifespan-scoped-dependencies.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md index 2a3585cf9..5f2d14131 100644 --- a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -26,7 +26,7 @@ it. You may declare a dependency as a lifespan scoped dependency by passing `dependency_scope="lifespan"` to the `Depends` function: -{* ../../docs_src/dependencies/tutorial013a_an_py39.py hl[16] *} +{* ../../docs_src/dependencies/tutorial013a_an_py39.py *} /// tip @@ -52,7 +52,7 @@ scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies will cache dependencies it already encountered before. However, you can disable this behavior by passing `use_cache=False` to `Depends`: -{* ../../docs_src/dependencies/tutorial013b_an_py39.py hl[16] *} +{* ../../docs_src/dependencies/tutorial013b_an_py39.py *} In this example, the `read_users` and `read_groups` endpoints are using `use_cache=False` whereas the `read_items` and `read_item` are using @@ -68,11 +68,11 @@ third and final connection will be shared across all requests for the Just like with endpoint scoped dependencies, lifespan scoped dependencies may use other lifespan scoped sub-dependencies themselves: -{* ../../docs_src/dependencies/tutorial013c_an_py39.py hl[16] *} +{* ../../docs_src/dependencies/tutorial013c_an_py39.py *} Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: -{* ../../docs_src/dependencies/tutorial013d_an_py39.py hl[16] *} +{* ../../docs_src/dependencies/tutorial013d_an_py39.py *} /// note From df180870b1e62df45e3db0b90c32563f2ceaeab8 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 21:57:14 +0200 Subject: [PATCH 23/29] Reworded some sections and fixed certain formatting issues --- .../lifespan-scoped-dependencies.md | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md index 5f2d14131..817cbbf4d 100644 --- a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -1,24 +1,34 @@ # Lifespan Scoped Dependencies +## Intro + So far we've used dependencies which are "endpoint scoped". Meaning, they are called again and again for every incoming request to the endpoint. However, -this is not ideal for all kinds of dependencies. +this is not always ideal: + +* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance. +* Sometimes dependencies need to have their values shared throughout the lifespan +of the application between multiple requests. -Sometimes dependencies have a large setup/teardown time, or there is a need -for their value to be shared throughout the lifespan of the application. An -example of this would be a connection to a database. Databases are typically + +An example of this would be a connection to a database. Databases are typically less efficient when working with lots of connections and would prefer that clients would create a single connection for their operations. -For such cases, you might want to use "lifespan scoped" dependencies. +For such cases can be solved by using "lifespan scoped dependencies". -## Intro -Lifespan scoped dependencies work similarly to the dependencies we've worked -with so far (which are endpoint scoped). However, they are called once and only -once in the application's lifespan (instead of being called again and again for -every request). The returned value will be shared across all requests that need -it. +## What is a lifespan scoped dependency? +Lifespan scoped dependencies work similarly to the (endpoint scoped) +dependencies we've worked with so far. However, unlike endpoint scoped +dependencies, lifespan scoped dependencies are called once and only +once in the application's lifespan: + +* During the application startup process, all lifespan scoped dependencies will +be called. +* Their returned value will be shared across all requests to the application. +* During the application's shutdown process, all lifespan scoped dependencies +will be gracefully teared down. ## Create a lifespan scoped dependency @@ -56,12 +66,12 @@ this behavior by passing `use_cache=False` to `Depends`: In this example, the `read_users` and `read_groups` endpoints are using `use_cache=False` whereas the `read_items` and `read_item` are using -`use_cache=True`. That means that we'll have a total of 3 connections created -for the duration of the application's lifespan. One connection will be shared -across all requests for the `read_items` and `read_item` endpoints. A second -connection will be shared across all requests for the `read_users` endpoint. The -third and final connection will be shared across all requests for the -`read_groups` endpoint. +`use_cache=True`. +That means that we'll have a total of 3 connections created +for the duration of the application's lifespan: +* One connection will be shared across all requests for the `read_items` and `read_item` endpoints. +* A second connection will be shared across all requests for the `read_users` endpoint. +* A third and final connection will be shared across all requests for the `read_groups` endpoint. ## Lifespan Scoped Sub-Dependencies @@ -92,8 +102,8 @@ Therefore, it is not possible for a lifespan scoped dependency to use any parameters that require the scope of an endpoint. That includes but not limited to: - * Parts of the request (like `Body`, `Query` and `Path`) - * The request/response objects themselves (like `Request`, `Response` and `WebSocket`) - * Endpoint scoped sub-dependencies. +* Parts of the request (like `Body`, `Query` and `Path`) +* The request/response objects themselves (like `Request`, `Response` and `WebSocket`) +* Endpoint scoped sub-dependencies. Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. From 31e2d6a21e57694911f40bf2cd744fd3ee3b8223 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Nov 2024 19:57:25 +0000 Subject: [PATCH 24/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../lifespan-scoped-dependencies.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md index 817cbbf4d..20098b1be 100644 --- a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -6,11 +6,11 @@ So far we've used dependencies which are "endpoint scoped". Meaning, they are called again and again for every incoming request to the endpoint. However, this is not always ideal: -* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance. -* Sometimes dependencies need to have their values shared throughout the lifespan +* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance. +* Sometimes dependencies need to have their values shared throughout the lifespan of the application between multiple requests. - + An example of this would be a connection to a database. Databases are typically less efficient when working with lots of connections and would prefer that clients would create a single connection for their operations. @@ -19,12 +19,12 @@ For such cases can be solved by using "lifespan scoped dependencies". ## What is a lifespan scoped dependency? -Lifespan scoped dependencies work similarly to the (endpoint scoped) -dependencies we've worked with so far. However, unlike endpoint scoped +Lifespan scoped dependencies work similarly to the (endpoint scoped) +dependencies we've worked with so far. However, unlike endpoint scoped dependencies, lifespan scoped dependencies are called once and only once in the application's lifespan: -* During the application startup process, all lifespan scoped dependencies will +* During the application startup process, all lifespan scoped dependencies will be called. * Their returned value will be shared across all requests to the application. * During the application's shutdown process, all lifespan scoped dependencies @@ -66,11 +66,11 @@ this behavior by passing `use_cache=False` to `Depends`: In this example, the `read_users` and `read_groups` endpoints are using `use_cache=False` whereas the `read_items` and `read_item` are using -`use_cache=True`. +`use_cache=True`. That means that we'll have a total of 3 connections created for the duration of the application's lifespan: -* One connection will be shared across all requests for the `read_items` and `read_item` endpoints. -* A second connection will be shared across all requests for the `read_users` endpoint. +* One connection will be shared across all requests for the `read_items` and `read_item` endpoints. +* A second connection will be shared across all requests for the `read_users` endpoint. * A third and final connection will be shared across all requests for the `read_groups` endpoint. From 73f56985aa4e83146c56e1567e5f4d9688e3d2ab Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 22:01:22 +0200 Subject: [PATCH 25/29] Fixed coverage --- docs_src/dependencies/tutorial013c_an_py39.py | 3 --- docs_src/dependencies/tutorial013d.py | 3 --- docs_src/dependencies/tutorial013d_an_py39.py | 3 --- tests/test_tutorial/test_dependencies/test_tutorial013b.py | 2 +- .../test_dependencies/test_tutorial013b_an_py39.py | 2 +- 5 files changed, 2 insertions(+), 11 deletions(-) diff --git a/docs_src/dependencies/tutorial013c_an_py39.py b/docs_src/dependencies/tutorial013c_an_py39.py index a111f1b6d..be945b7d2 100644 --- a/docs_src/dependencies/tutorial013c_an_py39.py +++ b/docs_src/dependencies/tutorial013c_an_py39.py @@ -19,9 +19,6 @@ class MyDatabaseConnection: async def __aexit__(self, exc_type, exc_val, exc_tb): pass - async def get_records(self, table_name: str) -> List[dict]: - pass - async def get_record(self, table_name: str, record_id: str) -> dict: pass diff --git a/docs_src/dependencies/tutorial013d.py b/docs_src/dependencies/tutorial013d.py index dd04b37bc..471bb2004 100644 --- a/docs_src/dependencies/tutorial013d.py +++ b/docs_src/dependencies/tutorial013d.py @@ -15,9 +15,6 @@ class MyDatabaseConnection: async def __aexit__(self, exc_type, exc_val, exc_tb): pass - async def get_records(self, table_name: str) -> List[dict]: - pass - async def get_record(self, table_name: str, record_id: str) -> dict: pass diff --git a/docs_src/dependencies/tutorial013d_an_py39.py b/docs_src/dependencies/tutorial013d_an_py39.py index 57ef5676a..fa6b0831b 100644 --- a/docs_src/dependencies/tutorial013d_an_py39.py +++ b/docs_src/dependencies/tutorial013d_an_py39.py @@ -15,9 +15,6 @@ class MyDatabaseConnection: async def __aexit__(self, exc_type, exc_val, exc_tb): pass - async def get_records(self, table_name: str) -> list[dict]: - pass - async def get_record(self, table_name: str, record_id: str) -> dict: pass diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b.py b/tests/test_tutorial/test_dependencies/test_tutorial013b.py index 084f0bffc..4483a3a17 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b.py @@ -33,7 +33,7 @@ class MockDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: self.get_record_count += 1 # Called for the sake of coverage. - await MyDatabaseConnection.get_records(self, table_name) + await MyDatabaseConnection.get_record(self, table_name, record_id) return { "table_name": table_name, "record_id": record_id, diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py index b12fa0637..6a489199f 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -37,7 +37,7 @@ class MockDatabaseConnection: async def get_record(self, table_name: str, record_id: str) -> dict: self.get_record_count += 1 # Called for the sake of coverage. - await MyDatabaseConnection.get_records(self, table_name) + await MyDatabaseConnection.get_record(self, table_name, record_id) return { "table_name": table_name, "record_id": record_id, From 4733cc96a4cdbae3632781fc4ec975bc347e1582 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Nov 2024 20:01:54 +0000 Subject: [PATCH 26/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs_src/dependencies/tutorial013c_an_py39.py | 2 +- docs_src/dependencies/tutorial013d.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs_src/dependencies/tutorial013c_an_py39.py b/docs_src/dependencies/tutorial013c_an_py39.py index be945b7d2..a64e72b8a 100644 --- a/docs_src/dependencies/tutorial013c_an_py39.py +++ b/docs_src/dependencies/tutorial013c_an_py39.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Annotated, List +from typing import Annotated from fastapi import Depends, FastAPI, Path from typing_extensions import Self diff --git a/docs_src/dependencies/tutorial013d.py b/docs_src/dependencies/tutorial013d.py index 471bb2004..01e2831d7 100644 --- a/docs_src/dependencies/tutorial013d.py +++ b/docs_src/dependencies/tutorial013d.py @@ -1,5 +1,3 @@ -from typing import List - from fastapi import Depends, FastAPI, Path from typing_extensions import Self From 98232099302dc0eb0c8163884164443c1e90f46f Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 22:16:34 +0200 Subject: [PATCH 27/29] Fixed more formatting issues --- .../docs/tutorial/dependencies/lifespan-scoped-dependencies.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md index 20098b1be..ba2c162aa 100644 --- a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -102,6 +102,7 @@ Therefore, it is not possible for a lifespan scoped dependency to use any parameters that require the scope of an endpoint. That includes but not limited to: + * Parts of the request (like `Body`, `Query` and `Path`) * The request/response objects themselves (like `Request`, `Response` and `WebSocket`) * Endpoint scoped sub-dependencies. From 8e66dfcfb20fe38d56834d24bb2089b023b192bb Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 23 Nov 2024 22:17:13 +0200 Subject: [PATCH 28/29] Fixed more formatting issues --- .../docs/tutorial/dependencies/lifespan-scoped-dependencies.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md index ba2c162aa..ba46da330 100644 --- a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -69,6 +69,7 @@ In this example, the `read_users` and `read_groups` endpoints are using `use_cache=True`. That means that we'll have a total of 3 connections created for the duration of the application's lifespan: + * One connection will be shared across all requests for the `read_items` and `read_item` endpoints. * A second connection will be shared across all requests for the `read_users` endpoint. * A third and final connection will be shared across all requests for the `read_groups` endpoint. From 731b93202c6c5865875b39e8c2ee61a80d930032 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jul 2025 14:34:49 +0000 Subject: [PATCH 29/29] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/models.py | 6 +++--- .../test_dependencies/test_tutorial013b.py | 18 +++++++++--------- .../test_tutorial013b_an_py39.py | 18 +++++++++--------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 9bbd81d37..1f4192a51 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -33,9 +33,9 @@ class LifespanDependant: elif self.name is not None: self.cache_key = (self.caller, self.name) else: - assert ( - self.index is not None - ), "Lifespan dependency must have an associated name or index." + assert self.index is not None, ( + "Lifespan dependency must have an associated name or index." + ) self.cache_key = (self.caller, self.index) diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b.py b/tests/test_tutorial/test_dependencies/test_tutorial013b.py index 4483a3a17..a7a092cf0 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b.py @@ -75,9 +75,9 @@ def test_dependency_usage(database_connection_mocks): users_connection = connection break - assert ( - users_connection is not None - ), "No connection was found for users endpoint" + assert users_connection is not None, ( + "No connection was found for users endpoint" + ) response = test_client.get("/groups") assert response.status_code == 200 @@ -89,9 +89,9 @@ def test_dependency_usage(database_connection_mocks): groups_connection = connection break - assert ( - groups_connection is not None - ), "No connection was found for groups endpoint" + assert groups_connection is not None, ( + "No connection was found for groups endpoint" + ) assert groups_connection.get_records_count == 1 items_connection = None @@ -100,9 +100,9 @@ def test_dependency_usage(database_connection_mocks): items_connection = connection break - assert ( - items_connection is not None - ), "No connection was found for items endpoint" + assert items_connection is not None, ( + "No connection was found for items endpoint" + ) response = test_client.get("/items") assert response.status_code == 200 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py index 6a489199f..e782f729f 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -80,9 +80,9 @@ def test_dependency_usage(database_connection_mocks): users_connection = connection break - assert ( - users_connection is not None - ), "No connection was found for users endpoint" + assert users_connection is not None, ( + "No connection was found for users endpoint" + ) response = test_client.get("/groups") assert response.status_code == 200 @@ -94,9 +94,9 @@ def test_dependency_usage(database_connection_mocks): groups_connection = connection break - assert ( - groups_connection is not None - ), "No connection was found for groups endpoint" + assert groups_connection is not None, ( + "No connection was found for groups endpoint" + ) assert groups_connection.get_records_count == 1 items_connection = None @@ -105,9 +105,9 @@ def test_dependency_usage(database_connection_mocks): items_connection = connection break - assert ( - items_connection is not None - ), "No connection was found for items endpoint" + assert items_connection is not None, ( + "No connection was found for items endpoint" + ) response = test_client.get("/items") assert response.status_code == 200