diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..625690a74 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,6 +1,8 @@ +from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum from typing import ( Any, + AsyncGenerator, Awaitable, Callable, Coroutine, @@ -15,12 +17,14 @@ from typing import ( from fastapi import routing from fastapi.datastructures import Default, DefaultPlaceholder +from fastapi.dependencies.utils import is_coroutine_callable from fastapi.exception_handlers import ( http_exception_handler, request_validation_exception_handler, websocket_request_validation_exception_handler, ) from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError +from fastapi.lifespan import resolve_lifespan_dependants from fastapi.logger import logger from fastapi.openapi.docs import ( get_redoc_html, @@ -29,9 +33,11 @@ from fastapi.openapi.docs import ( ) from fastapi.openapi.utils import get_openapi from fastapi.params import Depends +from fastapi.routing import merge_lifespan_context from fastapi.types import DecoratedCallable, IncEx from fastapi.utils import generate_unique_id from starlette.applications import Starlette +from starlette.concurrency import run_in_threadpool from starlette.datastructures import State from starlette.exceptions import HTTPException from starlette.middleware import Middleware @@ -929,12 +935,24 @@ class FastAPI(Starlette): """ ), ] = {} + if lifespan is None: + lifespan = FastAPI._internal_lifespan + else: + lifespan = merge_lifespan_context( + FastAPI._internal_lifespan, + lifespan + ) + + # Since we always use a lifespan, starlette will no longer run event + # handlers which are defined in the scope of the application. + # We therefore need to call them ourselves. + self._on_startup = on_startup or [] + self._on_shutdown = on_shutdown or [] + self.router: routing.APIRouter = routing.APIRouter( routes=routes, redirect_slashes=redirect_slashes, dependency_overrides_provider=self, - on_startup=on_startup, - on_shutdown=on_shutdown, lifespan=lifespan, default_response_class=default_response_class, dependencies=dependencies, @@ -963,6 +981,32 @@ class FastAPI(Starlette): self.middleware_stack: Union[ASGIApp, None] = None self.setup() + @asynccontextmanager + async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]: + async with AsyncExitStack() as exit_stack: + lifespan_scoped_dependencies = await resolve_lifespan_dependants( + app=self, + async_exit_stack=exit_stack + ) + try: + for handler in self._on_startup: + if is_coroutine_callable(handler): + await handler() + else: + await run_in_threadpool(handler) + yield { + "__fastapi__": { + "lifespan_scoped_dependencies": lifespan_scoped_dependencies + } + } + finally: + for handler in self._on_shutdown: + if is_coroutine_callable(handler): + await handler() + else: + await run_in_threadpool(handler) + + def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. This is called by FastAPI @@ -4492,7 +4536,13 @@ class FastAPI(Starlette): Read more about it in the [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated). """ - return self.router.on_event(event_type) + def decorator(func: DecoratedCallable) -> DecoratedCallable: + if event_type == "startup": + self._on_startup.append(func) + else: + self._on_shutdown.append(func) + return func + return decorator def middleware( self, diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..471f9c402 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase +from typing_extensions import TypeAlias @dataclass @@ -11,17 +12,41 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None +LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]] + +@dataclass +class LifespanDependant: + caller: Callable[..., Any] + dependencies: List["LifespanDependant"] = field(default_factory=list) + name: Optional[str] = None + call: Optional[Callable[..., Any]] = None + use_cache: bool = True + cache_key: LifespanDependantCacheKey = field(init=False) + + def __post_init__(self) -> None: + if self.use_cache: + self.cache_key = self.call + else: + self.cache_key = (self.caller, self.name) + + +EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] + @dataclass -class Dependant: +class EndpointDependant: + endpoint_dependencies: List["EndpointDependant"] = field(default_factory=list) + lifespan_dependencies: List[LifespanDependant] = field(default_factory=list) + name: Optional[str] = None + call: Optional[Callable[..., Any]] = None + use_cache: bool = True + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( + init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) cookie_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list) - dependencies: List["Dependant"] = field(default_factory=list) security_requirements: List[SecurityRequirement] = field(default_factory=list) - name: Optional[str] = None - call: Optional[Callable[..., Any]] = None request_param_name: Optional[str] = None websocket_param_name: Optional[str] = None http_connection_param_name: Optional[str] = None @@ -29,9 +54,16 @@ class Dependant: background_tasks_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None security_scopes: Optional[List[str]] = None - use_cache: bool = True path: Optional[str] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) + + # Kept for backwards compatibility + @property + def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: + return tuple(self.endpoint_dependencies + self.lifespan_dependencies) + +# Kept for backwards compatibility +Dependant = EndpointDependant +CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 5cebbf00f..c58232b6d 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -51,7 +51,14 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) -from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.models import ( + CacheKey, + EndpointDependant, + LifespanDependant, + LifespanDependantCacheKey, + SecurityRequirement, +) +from fastapi.exceptions import FastAPIError from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -112,8 +119,9 @@ def get_param_sub_dependant( param_name: str, depends: params.Depends, path: str, + caller: Callable[..., Any], security_scopes: Optional[List[str]] = None, -) -> Dependant: +) -> Union[EndpointDependant, LifespanDependant]: assert depends.dependency return get_sub_dependant( depends=depends, @@ -121,14 +129,25 @@ def get_param_sub_dependant( path=path, name=param_name, security_scopes=security_scopes, + caller=caller, ) -def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: +def get_parameterless_sub_dependant( + *, + depends: params.Depends, + path: str, + caller: Callable[..., Any] +) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" - return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) + return get_sub_dependant( + depends=depends, + dependency=depends.dependency, + path=path, + caller=caller + ) def get_sub_dependant( @@ -136,57 +155,69 @@ def get_sub_dependant( depends: params.Depends, dependency: Callable[..., Any], path: str, + caller: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, -) -> Dependant: - security_requirement = None - security_scopes = security_scopes or [] - if isinstance(depends, params.Security): - dependency_scopes = depends.scopes - security_scopes.extend(dependency_scopes) - if isinstance(dependency, SecurityBase): - use_scopes: List[str] = [] - if isinstance(dependency, (OAuth2, OpenIdConnect)): - use_scopes = security_scopes - security_requirement = SecurityRequirement( - security_scheme=dependency, scopes=use_scopes +) -> Union[EndpointDependant, LifespanDependant]: + if depends.dependency_scope == "lifespan": + return get_lifespan_dependant( + caller=caller, + call=depends.dependency, + name=name, + use_cache=depends.use_cache + ) + elif depends.dependency_scope == "endpoint": + security_requirement = None + security_scopes = security_scopes or [] + if isinstance(depends, params.Security): + dependency_scopes = depends.scopes + security_scopes.extend(dependency_scopes) + if isinstance(dependency, SecurityBase): + use_scopes: List[str] = [] + if isinstance(dependency, (OAuth2, OpenIdConnect)): + use_scopes = security_scopes + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes + ) + sub_dependant = get_endpoint_dependant( + path=path, + call=dependency, + name=name, + security_scopes=security_scopes, + use_cache=depends.use_cache, + ) + if security_requirement: + sub_dependant.security_requirements.append(security_requirement) + return sub_dependant + else: + raise ValueError( + f"Dependency {name} of {caller} has an invalid " + f"sub-dependency scope: {depends.dependency_scope}" ) - sub_dependant = get_dependant( - path=path, - call=dependency, - name=name, - security_scopes=security_scopes, - use_cache=depends.use_cache, - ) - if security_requirement: - sub_dependant.security_requirements.append(security_requirement) - return sub_dependant - - -CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] def get_flat_dependant( - dependant: Dependant, + dependant: EndpointDependant, *, skip_repeats: bool = False, visited: Optional[List[CacheKey]] = None, -) -> Dependant: +) -> EndpointDependant: if visited is None: visited = [] visited.append(dependant.cache_key) - flat_dependant = Dependant( + flat_dependant = EndpointDependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), + lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, - path=dependant.path, + path=dependant.path ) - for sub_dependant in dependant.dependencies: + for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( @@ -198,6 +229,7 @@ def get_flat_dependant( flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) flat_dependant.security_requirements.extend(flat_sub.security_requirements) + flat_dependant.lifespan_dependencies.extend(flat_sub.lifespan_dependencies) return flat_dependant @@ -211,7 +243,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: return fields -def get_flat_params(dependant: Dependant) -> List[ModelField]: +def get_flat_params(dependant: EndpointDependant) -> List[ModelField]: flat_dependant = get_flat_dependant(dependant, skip_repeats=True) path_params = _get_flat_fields_from_params(flat_dependant.path_params) query_params = _get_flat_fields_from_params(flat_dependant.query_params) @@ -254,18 +286,64 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) -def get_dependant( +def get_lifespan_dependant( + *, + caller: Callable[..., Any], + call: Callable[..., Any], + name: Optional[str] = None, + use_cache: bool = True, +) -> LifespanDependant: + dependency_signature = get_typed_signature(call) + signature_params = dependency_signature.parameters + dependant = LifespanDependant( + call=call, + name=name, + use_cache=use_cache, + caller=caller + ) + for param_name, param in signature_params.items(): + param_details = analyze_param( + param_name=param_name, + annotation=param.annotation, + value=param.default, + is_path_param=False, + ) + if param_details.depends is None: + raise FastAPIError( + f"Lifespan dependency {dependant.name} was defined with an " + f"invalid argument {param_name}. Lifespan dependencies may " + f"only use other lifespan dependencies as arguments.") + + if param_details.depends.dependency_scope != "lifespan": + raise FastAPIError( + "Lifespan dependency may not use " + "sub-dependencies of other scopes." + ) + + sub_dependant = get_lifespan_dependant( + name=param_name, + call=param_details.depends.dependency, + use_cache=param_details.depends.use_cache, + caller=call + ) + dependant.dependencies.append(sub_dependant) + + return dependant + + + +def get_endpoint_dependant( *, path: str, call: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, -) -> Dependant: +) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - dependant = Dependant( + dependant = EndpointDependant( call=call, name=name, path=path, @@ -281,13 +359,28 @@ def get_dependant( is_path_param=is_path_param, ) if param_details.depends is not None: - sub_dependant = get_param_sub_dependant( - param_name=param_name, - depends=param_details.depends, - path=path, - security_scopes=security_scopes, - ) - dependant.dependencies.append(sub_dependant) + if param_details.depends.dependency_scope == "endpoint": + sub_dependant = get_param_sub_dependant( + param_name=param_name, + depends=param_details.depends, + path=path, + security_scopes=security_scopes, + caller=call, + ) + dependant.endpoint_dependencies.append(sub_dependant) + elif param_details.depends.dependency_scope == "lifespan": + sub_dependant = get_lifespan_dependant( + caller=call, + call=param_details.depends.dependency, + name=param_name, + use_cache=param_details.depends.use_cache, + ) + dependant.lifespan_dependencies.append(sub_dependant) + else: + raise FastAPIError( + f"Dependency \"{param_name}\" of `{call}` has an invalid " + f"sub-dependency scope: \"{param_details.depends.dependency_scope}\"" + ) continue if add_non_field_param_to_dependency( param_name=param_name, @@ -306,8 +399,12 @@ def get_dependant( return dependant +# Kept for backwards compatibility +get_dependant = get_endpoint_dependant + + def add_non_field_param_to_dependency( - *, param_name: str, type_annotation: Any, dependant: Dependant + *, param_name: str, type_annotation: Any, dependant: EndpointDependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name @@ -501,7 +598,7 @@ def analyze_param( return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) -def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: +def add_param_to_fields(*, field: ModelField, dependant: EndpointDependant) -> None: field_info = field.field_info field_info_in = getattr(field_info, "in_", None) if field_info_in == params.ParamTypes.path: @@ -550,6 +647,82 @@ async def solve_generator( return await stack.enter_async_context(cm) +@dataclass +class SolvedLifespanDependant: + value: Any + dependency_cache: Dict[Callable[..., Any], Any] + + +async def solve_lifespan_dependant( + *, + dependant: LifespanDependant, + dependency_overrides_provider: Optional[Any] = None, + dependency_cache: Optional[Dict[LifespanDependantCacheKey, Callable[..., Any]]] = None, + async_exit_stack: AsyncExitStack, +) -> SolvedLifespanDependant: + dependency_cache = dependency_cache or {} + if dependant.use_cache and dependant.cache_key in dependency_cache: + return SolvedLifespanDependant( + value=dependency_cache[dependant.cache_key], + dependency_cache=dependency_cache, + ) + + dependency_arguments: Dict[str, Any] = {} + sub_dependant: LifespanDependant + for sub_dependant in dependant.dependencies: + sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) + sub_dependant.cache_key = cast( + Callable[..., Any], sub_dependant.cache_key + ) + assert sub_dependant.name, ( + "Lifespan scoped dependencies should not be able to have " + "subdependencies with no name" + ) + + sub_dependant_to_solve = sub_dependant + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): + original_call = sub_dependant.call + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) + sub_dependant_to_solve = get_lifespan_dependant( + call=call, + name=sub_dependant.name, + caller=dependant.call + ) + + solved_sub_dependant = await solve_lifespan_dependant( + dependant=sub_dependant_to_solve, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack, + ) + dependency_cache.update(solved_sub_dependant.dependency_cache) + dependency_arguments[sub_dependant.name] = solved_sub_dependant.value + + if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call): + value = await solve_generator( + call=dependant.call, + stack=async_exit_stack, + sub_values=dependency_arguments + ) + elif is_coroutine_callable(dependant.call): + value = await dependant.call(**dependency_arguments) + else: + value = await run_in_threadpool(dependant.call, **dependency_arguments) + + if dependant.cache_key not in dependency_cache: + dependency_cache[dependant.cache_key] = value + + return SolvedLifespanDependant( + value=value, + dependency_cache=dependency_cache, + ) + + @dataclass class SolvedDependency: values: Dict[str, Any] @@ -562,7 +735,7 @@ class SolvedDependency: async def solve_dependencies( *, request: Union[Request, WebSocket], - dependant: Dependant, + dependant: EndpointDependant, body: Optional[Union[Dict[str, Any], FormData]] = None, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, @@ -573,13 +746,35 @@ async def solve_dependencies( ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] + + for sub_dependant in dependant.lifespan_dependencies: + if sub_dependant.name is None: + continue + try: + lifespan_scoped_dependencies = request.state.__fastapi__[ + "lifespan_scoped_dependencies"] + except AttributeError as e: + raise FastAPIError( + "FastAPI's internal lifespan was not initialized" + ) from e + + try: + value = lifespan_scoped_dependencies[sub_dependant.cache_key] + except KeyError as e: + raise FastAPIError( + f"Dependency {sub_dependant.name} of {dependant.call} " + f"was not initialized." + ) from e + + values[sub_dependant.name] = value + if response is None: response = Response() del response.headers["content-length"] response.status_code = None # type: ignore + dependency_cache = dependency_cache or {} - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: + for sub_dependant in dependant.endpoint_dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.cache_key = cast( Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key @@ -595,7 +790,7 @@ async def solve_dependencies( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore - use_sub_dependant = get_dependant( + use_sub_dependant = get_endpoint_dependant( path=use_path, call=call, name=sub_dependant.name, @@ -910,7 +1105,7 @@ async def request_body_to_args( def get_body_field( - *, flat_dependant: Dependant, name: str, embed_body_fields: bool + *, flat_dependant: EndpointDependant, name: str, embed_body_fields: bool ) -> Optional[ModelField]: """ Get a ModelField representing the request body for a path operation, combining diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py new file mode 100644 index 000000000..184c943d7 --- /dev/null +++ b/fastapi/lifespan.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Callable, Dict, List + +from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey +from fastapi.dependencies.utils import solve_lifespan_dependant +from fastapi.routing import APIRoute + +if TYPE_CHECKING: + from fastapi import FastAPI + + +def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: + lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} + for route in app.router.routes: + if not isinstance(route, APIRoute): + continue + + for sub_dependant in route.lifespan_dependencies: + if sub_dependant.cache_key in lifespan_dependants_cache: + continue + + lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant + + return list(lifespan_dependants_cache.values()) + + +async def resolve_lifespan_dependants( + *, + app: FastAPI, + async_exit_stack: AsyncExitStack +) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: + lifespan_dependants = _get_lifespan_dependants(app) + dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} + for lifespan_dependant in lifespan_dependants: + solved_dependency = await solve_lifespan_dependant( + dependant=lifespan_dependant, + dependency_overrides_provider=app, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack + ) + + dependency_cache.update(solved_dependency.dependency_cache) + + return dependency_cache diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 947eca948..425e919cf 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -15,7 +15,7 @@ from fastapi._compat import ( lenient_issubclass, ) from fastapi.datastructures import DefaultPlaceholder -from fastapi.dependencies.models import Dependant +from fastapi.dependencies.models import EndpointDependant from fastapi.dependencies.utils import ( _get_flat_fields_from_params, get_flat_dependant, @@ -75,7 +75,7 @@ status_code_ranges: Dict[str, str] = { def get_openapi_security_definitions( - flat_dependant: Dependant, + flat_dependant: EndpointDependant, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: security_definitions = {} operation_security = [] @@ -93,7 +93,7 @@ def get_openapi_security_definitions( def _get_openapi_operation_parameters( *, - dependant: Dependant, + dependant: EndpointDependant, schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, field_mapping: Dict[ diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 7ddaace25..9ff744414 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example +from fastapi.params import DependencyScope from typing_extensions import Annotated, Doc, deprecated _Unset: Any = Undefined @@ -2244,6 +2247,33 @@ def Depends( # noqa: N802 """ ), ] = True, + dependency_scope: Annotated[ + DependencyScope, + Doc( + """ + The scope in which the dependency value should be evaluated. Can be + either `"endpoint"` or `"lifespan"`. + + If `dependency_scope` is set to "endpoint" (the default), the + dependency will be setup and teardown for every request. + + If `dependency_scope` is set to `"lifespan"` the dependency would + be setup at the start of the entire application's lifespan. The + evaluated dependency would be then reused across all endpoints. + The dependency would be teared down as a part of the application's + shutdown process. + + Note that dependencies defined with the `"endpoint"` scope may use + sub-dependencies defined with the `"lifespan"` scope, but not the + other way around; + Dependencies defined with the `"lifespan"` scope may not use + sub-dependencies with `"endpoint"` scope, nor can they use + other "endpoint scoped" arguments such as "Path", "Body", "Query", + or any other annotation which does not make sense in a scope of an + application's entire lifespan. + """ + ) + ] = "endpoint" ) -> Any: """ Declare a FastAPI dependency. @@ -2274,7 +2304,7 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends(dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope) def Security( # noqa: N802 diff --git a/fastapi/params.py b/fastapi/params.py index 90ca7cb01..f655acba8 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -4,11 +4,12 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi.openapi.models import Example from pydantic.fields import FieldInfo -from typing_extensions import Annotated, deprecated +from typing_extensions import Annotated, Literal, TypeAlias, deprecated from ._compat import PYDANTIC_V2, PYDANTIC_VERSION, Undefined _Unset: Any = Undefined +DependencyScope: TypeAlias = Literal["endpoint", "lifespan"] class ParamTypes(Enum): @@ -759,15 +760,25 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + dependency_scope: DependencyScope = "endpoint" ): self.dependency = dependency self.use_cache = use_cache + self.dependency_scope = dependency_scope def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" + if self.dependency_scope == "endpoint": + dependency_scope = "" + else: + dependency_scope = f", dependency_scope=\"{self.dependency_scope}\"" + + return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})" class Security(Depends): @@ -778,5 +789,9 @@ class Security(Depends): scopes: Optional[Sequence[str]] = None, use_cache: bool = True, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__( + dependency=dependency, + use_cache=use_cache, + dependency_scope="endpoint" + ) self.scopes = scopes or [] diff --git a/fastapi/routing.py b/fastapi/routing.py index 86e303602..c3f446037 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -31,11 +31,11 @@ from fastapi._compat import ( lenient_issubclass, ) from fastapi.datastructures import Default, DefaultPlaceholder -from fastapi.dependencies.models import Dependant +from fastapi.dependencies.models import EndpointDependant, LifespanDependant from fastapi.dependencies.utils import ( _should_embed_body_fields, get_body_field, - get_dependant, + get_endpoint_dependant, get_flat_dependant, get_parameterless_sub_dependant, get_typed_return_annotation, @@ -73,7 +73,7 @@ from starlette.routing import ( from starlette.routing import Mount as Mount # noqa from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket -from typing_extensions import Annotated, Doc, deprecated +from typing_extensions import Annotated, Doc, assert_never, deprecated def _prepare_response_content( @@ -123,7 +123,7 @@ def _prepare_response_content( return res -def _merge_lifespan_context( +def merge_lifespan_context( original_context: Lifespan[Any], nested_context: Lifespan[Any] ) -> Lifespan[Any]: @asynccontextmanager @@ -202,7 +202,7 @@ async def serialize_response( async def run_endpoint_function( - *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool + *, dependant: EndpointDependant, values: Dict[str, Any], is_coroutine: bool ) -> Any: # Only called by get_request_handler. Has been split into its own function to # facilitate profiling endpoints, since inner functions are harder to profile. @@ -215,7 +215,7 @@ async def run_endpoint_function( def get_request_handler( - dependant: Dependant, + dependant: EndpointDependant, body_field: Optional[ModelField] = None, status_code: Optional[int] = None, response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse), @@ -358,7 +358,7 @@ def get_request_handler( def get_websocket_app( - dependant: Dependant, + dependant: EndpointDependant, dependency_overrides_provider: Optional[Any] = None, embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: @@ -400,12 +400,20 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + sub_dependant = get_parameterless_sub_dependant( + depends=depends, + path=self.path_format, + caller=self ) + if depends.dependency_scope == "endpoint": + self.dependant.endpoint_dependencies.insert(0, sub_dependant) + elif depends.dependency_scope == "lifespan": + self.dependant.lifespan_dependencies.insert(0, sub_dependant) + else: + assert_never(depends.dependency_scope) + self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params @@ -424,6 +432,10 @@ class APIWebSocketRoute(routing.WebSocketRoute): child_scope["route"] = self return match, child_scope + @property + def lifespan_dependencies(self) -> List[LifespanDependant]: + return self._flat_dependant.lifespan_dependencies + class APIRoute(routing.Route): def __init__( @@ -549,12 +561,19 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), + sub_dependant = get_parameterless_sub_dependant( + depends=depends, + path=self.path_format, + caller=self.__call__ ) + if depends.dependency_scope == "endpoint": + self.dependant.endpoint_dependencies.insert(0, sub_dependant) + elif depends.dependency_scope == "lifespan": + self.dependant.lifespan_dependencies.insert(0, sub_dependant) + else: + assert_never(depends.dependency_scope) self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params @@ -589,6 +608,10 @@ class APIRoute(routing.Route): child_scope["route"] = self return match, child_scope + @property + def lifespan_dependencies(self) -> List[LifespanDependant]: + return self._flat_dependant.lifespan_dependencies + class APIRouter(routing.Router): """ @@ -1356,7 +1379,7 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) - self.lifespan_context = _merge_lifespan_context( + self.lifespan_context = merge_lifespan_context( self.lifespan_context, router.lifespan_context, ) diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py new file mode 100644 index 000000000..40d82f0e0 --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies.py @@ -0,0 +1,703 @@ +from enum import StrEnum, auto +from typing import Any, AsyncGenerator, List, Tuple, TypeVar + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, +) +from fastapi.exceptions import FastAPIError +from fastapi.params import Security +from fastapi.security import SecurityScopes +from starlette.testclient import TestClient +from typing_extensions import Annotated, Generator, Literal, assert_never + +T = TypeVar('T') + + +class DependencyStyle(StrEnum): + SYNC_FUNCTION = auto() + ASYNC_FUNCTION = auto() + SYNC_GENERATOR = auto() + ASYNC_GENERATOR = auto() + + +class DependencyFactory: + def __init__( + self, + dependency_style: DependencyStyle, *, + should_error: bool = False + ): + self.activation_times = 0 + self.deactivation_times = 0 + self.dependency_style = dependency_style + self._should_error = should_error + + def get_dependency(self): + if self.dependency_style == DependencyStyle.SYNC_FUNCTION: + return self._synchronous_function_dependency + + if self.dependency_style == DependencyStyle.SYNC_GENERATOR: + return self._synchronous_generator_dependency + + if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: + return self._asynchronous_function_dependency + + if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: + return self._asynchronous_generator_dependency + + assert_never(self.dependency_style) + + async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + yield self.activation_times + self.deactivation_times += 1 + + def _synchronous_generator_dependency(self) -> Generator[T, None, None]: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + yield self.activation_times + self.deactivation_times += 1 + + async def _asynchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + return self.activation_times + + def _synchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise ValueError(self.activation_times) + + return self.activation_times + + +def _expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + response = client.post(url) + response.raise_for_status() + assert response.json() == expected_response + + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION + ): + assert dependency_factory.deactivation_times == expected_activation_times + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + @router.post("/test") + async def endpoint( + dependency: Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + )] + ) -> None: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1 + ) + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends]) + + @app.post("/test") + async def endpoint() -> None: + return None + else: + app = FastAPI() + router = APIRouter(dependencies=[depends]) + + @router.post("/test") + async def endpoint() -> None: + return None + + app.include_router(router) + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 + ) + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"] +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + @router.post("/test") + async def endpoint( + dependency: Annotated[List[int], Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + )] + ) -> List[int]: + return dependency + + if routing_style == "router": + app.include_router(router) + + if use_cache: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1]), + ("/test", [1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1 + ) + else: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2]), + ("/test", [1, 2]), + ], + dependency_factory=dependency_factory, + expected_activation_times=2 + ) + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + @router.post("/test1") + async def endpoint( + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)] + ) -> List[int]: + return [dependency1, dependency2, dependency3] + + if routing_style == "router": + app.include_router(router) + + if use_cache: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1 + ) + else: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test1", [1, 2, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=3 + ) + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + @router.post("/test1") + async def endpoint( + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)] + ) -> List[int]: + return [dependency1, dependency2, dependency3] + + @router.post("/test2") + async def endpoint2( + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)] + ) -> List[int]: + return [dependency1, dependency2, dependency3] + + if routing_style == "router": + app.include_router(router) + + if use_cache: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1 + ) + else: + _expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=5 + ) + +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends], + ) -> int: + return dependency + + if routing_style == "router": + app.include_router(router) + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1 + ) + + +@pytest.mark.parametrize("annotation", [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, +]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation +): + async def dependency_func(param: annotation) -> None: + yield + + app = FastAPI() + + with pytest.raises(FastAPIError): + @app.post("/test") + async def endpoint( + dependency: Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan")] + ) -> None: + return + + +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( + dependency_style: DependencyStyle +): + dependency_factory = DependencyFactory(dependency_style) + + async def lifespan_scoped_dependency( + param: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )] + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + @app.post("/test") + async def endpoint( + dependency: Annotated[int, Depends( + lifespan_scoped_dependency, + dependency_scope="lifespan" + )] + ) -> int: + return dependency + + _expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2 + ) + + +@pytest.mark.parametrize("depends_class", [Depends, Security]) +@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ + "websocket", "endpoint" +]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, + route_type +): + async def sub_dependency() -> None: + pass + + async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + yield + + app = FastAPI() + route_decorator = route_type(app, "/test") + + with pytest.raises(FastAPIError): + @route_decorator + async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + ) -> None: + return + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_dependencies_must_provide_correct_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + with pytest.raises(FastAPIError): + @router.post("/test") + async def endpoint( + dependency: Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + )] + ) -> None: + assert dependency == 1 + return dependency + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_incorrect_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + with pytest.raises(FastAPIError): + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} + + try: + with pytest.raises(FastAPIError): + client.post("/test") + finally: + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_internal_lifespan( + dependency_style: DependencyStyle, + routing_style, + use_cache +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + internal_state = client.app_state["__fastapi__"] + del client.app_state["__fastapi__"] + + try: + with pytest.raises(FastAPIError): + client.post("/test") + finally: + client.app_state["__fastapi__"] = internal_state + + +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): + dependency_factory= DependencyFactory(dependency_style, should_error=True) + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + @router.post("/test") + async def endpoint( + dependency: Annotated[int, depends] + ) -> int: + assert dependency == 1 + return dependency + + if routing_style == "router_endpoint": + app.include_router(router) + + with pytest.raises(ValueError) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) + + +# TODO: Add tests for dependency_overrides +# TODO: Add a websocket equivalent to all tests diff --git a/tests/test_params_repr.py b/tests/test_params_repr.py index bfc7bed09..10f044888 100644 --- a/tests/test_params_repr.py +++ b/tests/test_params_repr.py @@ -1,5 +1,6 @@ from typing import Any, List +import pytest from dirty_equals import IsOneOf from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query @@ -143,10 +144,16 @@ def test_body_repr_list(): assert repr(Body([])) == "Body([])" -def test_depends_repr(): - assert repr(Depends()) == "Depends(NoneType)" - assert repr(Depends(get_user)) == "Depends(get_user)" - assert repr(Depends(use_cache=False)) == "Depends(NoneType, use_cache=False)" - assert ( - repr(Depends(get_user, use_cache=False)) == "Depends(get_user, use_cache=False)" - ) +@pytest.mark.parametrize(["depends", "expected_repr"], [ + [Depends(), "Depends(NoneType)"], + [Depends(get_user), "Depends(get_user)"], + [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], + [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], + + [Depends(dependency_scope="lifespan"), "Depends(NoneType, dependency_scope=\"lifespan\")"], + [Depends(get_user, dependency_scope="lifespan"), "Depends(get_user, dependency_scope=\"lifespan\")"], + [Depends(use_cache=False, dependency_scope="lifespan"), "Depends(NoneType, use_cache=False, dependency_scope=\"lifespan\")"], + [Depends(get_user, use_cache=False, dependency_scope="lifespan"), "Depends(get_user, use_cache=False, dependency_scope=\"lifespan\")"], +]) +def test_depends_repr(depends, expected_repr): + assert repr(depends) == expected_repr diff --git a/tests/test_router_events.py b/tests/test_router_events.py index dd7ff3314..8289a7301 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -199,6 +199,9 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None: "app_specific": True, "router_specific": True, "overridden": "app", + "__fastapi__": { + "lifespan_scoped_dependencies": {} + }, } @@ -216,7 +219,11 @@ def test_merged_no_return_lifespans_return_none() -> None: app.include_router(router) with TestClient(app) as client: - assert not client.app_state + assert client.app_state == { + "__fastapi__": { + "lifespan_scoped_dependencies": {} + } + } def test_merged_mixed_state_lifespans() -> None: @@ -239,4 +246,9 @@ def test_merged_mixed_state_lifespans() -> None: app.include_router(router) with TestClient(app) as client: - assert client.app_state == {"router": True} + assert client.app_state == { + "router": True, + "__fastapi__": { + "lifespan_scoped_dependencies": {} + } + }