diff --git a/fastapi/applications.py b/fastapi/applications.py index 625690a74..7da120a34 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -946,8 +946,13 @@ class FastAPI(Starlette): # Since we always use a lifespan, starlette will no longer run event # handlers which are defined in the scope of the application. # We therefore need to call them ourselves. - self._on_startup = on_startup or [] - self._on_shutdown = on_shutdown or [] + if on_startup is None: + on_startup = [] + + if on_shutdown is None: + on_shutdown = [] + self._on_startup = list(on_startup) + self._on_shutdown = list(on_shutdown) self.router: routing.APIRouter = routing.APIRouter( routes=routes, @@ -982,7 +987,7 @@ class FastAPI(Starlette): self.setup() @asynccontextmanager - async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]: + async def _internal_lifespan(self) -> AsyncGenerator[Dict[str, Any], None]: async with AsyncExitStack() as exit_stack: lifespan_scoped_dependencies = await resolve_lifespan_dependants( app=self, diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 471f9c402..b68f0339c 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -12,22 +12,28 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None -LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]] +LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]] @dataclass class LifespanDependant: + call: Callable[..., Any] caller: Callable[..., Any] dependencies: List["LifespanDependant"] = field(default_factory=list) name: Optional[str] = None - call: Optional[Callable[..., Any]] = None use_cache: bool = True + index: Optional[int] = None cache_key: LifespanDependantCacheKey = field(init=False) def __post_init__(self) -> None: if self.use_cache: self.cache_key = self.call - else: + elif self.name is not None: self.cache_key = (self.caller, self.name) + else: + assert self.index is not None, ( + "Lifespan dependency must have an associated name or index." + ) + self.cache_key = (self.caller, self.index) EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] @@ -39,6 +45,7 @@ class EndpointDependant: name: Optional[str] = None call: Optional[Callable[..., Any]] = None use_cache: bool = True + index: Optional[int] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( init=False) path_params: List[ModelField] = field(default_factory=list) @@ -62,7 +69,16 @@ class EndpointDependant: # Kept for backwards compatibility @property def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: - return tuple(self.endpoint_dependencies + self.lifespan_dependencies) + lifespan_dependencies = cast( + List[Union[EndpointDependant, LifespanDependant]], + self.lifespan_dependencies + ) + endpoint_dependencies = cast( + List[Union[EndpointDependant, LifespanDependant]], + self.endpoint_dependencies + ) + + return tuple(lifespan_dependencies + endpoint_dependencies) # Kept for backwards compatibility Dependant = EndpointDependant diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0e9cfe244..46a76e8c4 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -54,11 +54,16 @@ from fastapi.concurrency import ( from fastapi.dependencies.models import ( CacheKey, EndpointDependant, + EndpointDependantCacheKey, LifespanDependant, LifespanDependantCacheKey, SecurityRequirement, ) -from fastapi.exceptions import FastAPIError +from fastapi.exceptions import ( + DependencyScopeConflict, + InvalidDependencyScope, + UninitializedLifespanDependency, +) from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -78,7 +83,7 @@ from starlette.datastructures import ( from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, assert_never, get_args, get_origin multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' @@ -137,7 +142,8 @@ def get_parameterless_sub_dependant( *, depends: params.Depends, path: str, - caller: Callable[..., Any] + caller: Callable[..., Any], + index: int ) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency @@ -146,7 +152,8 @@ def get_parameterless_sub_dependant( depends=depends, dependency=depends.dependency, path=path, - caller=caller + caller=caller, + index=index ) @@ -158,13 +165,15 @@ def get_sub_dependant( caller: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, + index: Optional[int] = None, ) -> Union[EndpointDependant, LifespanDependant]: if depends.dependency_scope == "lifespan": return get_lifespan_dependant( caller=caller, - call=depends.dependency, + call=dependency, name=name, - use_cache=depends.use_cache + use_cache=depends.use_cache, + index=index ) elif depends.dependency_scope == "endpoint": security_requirement = None @@ -185,14 +194,15 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + index=index ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) return sub_dependant else: - raise ValueError( - f"Dependency {name} of {caller} has an invalid " - f"sub-dependency scope: {depends.dependency_scope}" + raise InvalidDependencyScope( + f"Dependency \"{name}\" of {caller} has an invalid " + f"scope: \"{depends.dependency_scope}\"" ) @@ -292,6 +302,7 @@ def get_lifespan_dependant( call: Callable[..., Any], name: Optional[str] = None, use_cache: bool = True, + index: Optional[int] = None ) -> LifespanDependant: dependency_signature = get_typed_signature(call) signature_params = dependency_signature.parameters @@ -299,7 +310,8 @@ def get_lifespan_dependant( call=call, name=name, use_cache=use_cache, - caller=caller + caller=caller, + index=index ) for param_name, param in signature_params.items(): param_details = analyze_param( @@ -309,17 +321,23 @@ def get_lifespan_dependant( is_path_param=False, ) if param_details.depends is None: - raise FastAPIError( - f"Lifespan dependency {dependant.name} was defined with an " - f"invalid argument {param_name}. Lifespan dependencies may " - f"only use other lifespan dependencies as arguments.") + raise DependencyScopeConflict( + f"Lifespan scoped dependency \"{dependant.name}\" was defined " + f"with an invalid argument: \"{param_name}\" which is " + f"\"endpoint\" scoped. Lifespan scoped dependencies may only " + f"use lifespan scoped sub-dependencies.") if param_details.depends.dependency_scope != "lifespan": - raise FastAPIError( - "Lifespan dependency may not use " - "sub-dependencies of other scopes." + raise DependencyScopeConflict( + f"Lifespan scoped dependency {dependant.name} was defined with the " + f"sub-dependency \"{param_name}\" which is " + f"\"{param_details.depends.dependency_scope}\" scoped. " + f"Lifespan scoped dependencies may only use lifespan scoped " + f"sub-dependencies." ) + assert param_details.depends.dependency is not None + sub_dependant = get_lifespan_dependant( name=param_name, call=param_details.depends.dependency, @@ -339,6 +357,7 @@ def get_endpoint_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + index: Optional[int] = None ) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -349,6 +368,7 @@ def get_endpoint_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + index=index ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -359,28 +379,19 @@ def get_endpoint_dependant( is_path_param=is_path_param, ) if param_details.depends is not None: - if param_details.depends.dependency_scope == "endpoint": - sub_dependant = get_param_sub_dependant( - param_name=param_name, - depends=param_details.depends, - path=path, - security_scopes=security_scopes, - caller=call, - ) + sub_dependant = get_param_sub_dependant( + param_name=param_name, + depends=param_details.depends, + path=path, + security_scopes=security_scopes, + caller=call, + ) + if isinstance(sub_dependant, EndpointDependant): dependant.endpoint_dependencies.append(sub_dependant) - elif param_details.depends.dependency_scope == "lifespan": - sub_dependant = get_lifespan_dependant( - caller=call, - call=param_details.depends.dependency, - name=param_name, - use_cache=param_details.depends.use_cache, - ) + elif isinstance(sub_dependant, LifespanDependant): dependant.lifespan_dependencies.append(sub_dependant) else: - raise FastAPIError( - f"Dependency \"{param_name}\" of `{call}` has an invalid " - f"sub-dependency scope: \"{param_details.depends.dependency_scope}\"" - ) + assert_never(sub_dependant) continue if add_non_field_param_to_dependency( param_name=param_name, @@ -652,7 +663,7 @@ async def solve_generator( @dataclass class SolvedLifespanDependant: value: Any - dependency_cache: Dict[Callable[..., Any], Any] + dependency_cache: Dict[LifespanDependantCacheKey, Any] async def solve_lifespan_dependant( @@ -669,35 +680,33 @@ async def solve_lifespan_dependant( dependency_cache=dependency_cache, ) - dependency_arguments: Dict[str, Any] = {} - sub_dependant: LifespanDependant - for sub_dependant in dependant.dependencies: - sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Callable[..., Any], sub_dependant.cache_key + call = dependant.call + dependant_to_solve = dependant + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): + call = getattr( + dependency_overrides_provider, + "dependency_overrides", + {} + ).get(dependant.call, dependant.call) + dependant_to_solve = get_lifespan_dependant( + caller=dependant.caller, + call=call, + name=dependant.name, + use_cache=dependant.use_cache, + index=dependant.index, ) + + dependency_arguments: Dict[str, Any] = {} + for sub_dependant in dependant_to_solve.dependencies: assert sub_dependant.name, ( "Lifespan scoped dependencies should not be able to have " "subdependencies with no name" ) - - sub_dependant_to_solve = sub_dependant - if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides - ): - original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) - sub_dependant_to_solve = get_lifespan_dependant( - call=call, - name=sub_dependant.name, - caller=dependant.call - ) - solved_sub_dependant = await solve_lifespan_dependant( - dependant=sub_dependant_to_solve, + dependant=sub_dependant, dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, async_exit_stack=async_exit_stack, @@ -705,16 +714,16 @@ async def solve_lifespan_dependant( dependency_cache.update(solved_sub_dependant.dependency_cache) dependency_arguments[sub_dependant.name] = solved_sub_dependant.value - if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call): + if is_gen_callable(call) or is_async_gen_callable(call): value = await solve_generator( - call=dependant.call, + call=call, stack=async_exit_stack, sub_values=dependency_arguments ) - elif is_coroutine_callable(dependant.call): - value = await dependant.call(**dependency_arguments) + elif is_coroutine_callable(call): + value = await call(**dependency_arguments) else: - value = await run_in_threadpool(dependant.call, **dependency_arguments) + value = await run_in_threadpool(call, **dependency_arguments) if dependant.cache_key not in dependency_cache: dependency_cache[dependant.cache_key] = value @@ -731,7 +740,7 @@ class SolvedDependency: errors: List[Any] background_tasks: Optional[StarletteBackgroundTasks] response: Response - dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] + dependency_cache: Dict[EndpointDependantCacheKey, Any] async def solve_dependencies( @@ -742,33 +751,34 @@ async def solve_dependencies( background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, + dependency_cache: Optional[Dict[EndpointDependantCacheKey, Any]] = None, async_exit_stack: AsyncExitStack, embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] - for sub_dependant in dependant.lifespan_dependencies: - if sub_dependant.name is None: + for lifespan_sub_dependant in dependant.lifespan_dependencies: + if lifespan_sub_dependant.name is None: continue + try: lifespan_scoped_dependencies = request.state.__fastapi__[ "lifespan_scoped_dependencies"] - except AttributeError as e: - raise FastAPIError( - "FastAPI's internal lifespan was not initialized" + except (AttributeError, KeyError) as e: + raise UninitializedLifespanDependency( + "FastAPI's internal lifespan was not initialized correctly." ) from e try: - value = lifespan_scoped_dependencies[sub_dependant.cache_key] + value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key] except KeyError as e: - raise FastAPIError( - f"Dependency {sub_dependant.name} of {dependant.call} " - f"was not initialized." + raise UninitializedLifespanDependency( + f"Dependency \"{lifespan_sub_dependant.name}\" of " + f"`{dependant.call}` was not initialized correctly." ) from e - values[sub_dependant.name] = value + values[lifespan_sub_dependant.name] = value if response is None: response = Response() diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86..95fd60477 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -146,6 +146,22 @@ class FastAPIError(RuntimeError): """ +class DependencyError(FastAPIError): + pass + + +class InvalidDependencyScope(DependencyError): + pass + + +class DependencyScopeConflict(DependencyError): + pass + + +class UninitializedLifespanDependency(DependencyError): + pass + + class ValidationException(Exception): def __init__(self, errors: Sequence[Any]) -> None: self._errors = errors diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py index 184c943d7..d725a5e30 100644 --- a/fastapi/lifespan.py +++ b/fastapi/lifespan.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey from fastapi.dependencies.utils import solve_lifespan_dependant -from fastapi.routing import APIRoute +from fastapi.routing import APIRoute, APIWebSocketRoute if TYPE_CHECKING: from fastapi import FastAPI @@ -14,7 +14,7 @@ if TYPE_CHECKING: def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} for route in app.router.routes: - if not isinstance(route, APIRoute): + if not isinstance(route, (APIWebSocketRoute, APIRoute)): continue for sub_dependant in route.lifespan_dependencies: diff --git a/fastapi/routing.py b/fastapi/routing.py index 50b51a352..e11edaa13 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -401,15 +401,18 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: + for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, - caller=self + caller=self.__call__, + index=i ) if depends.dependency_scope == "endpoint": + assert isinstance(sub_dependant, EndpointDependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant) elif depends.dependency_scope == "lifespan": + assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) else: assert_never(depends.dependency_scope) @@ -564,15 +567,18 @@ class APIRoute(routing.Route): assert callable(endpoint), "An endpoint must be a callable" self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: + for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, - caller=self.__call__ + caller=self.__call__, + index=i ) if depends.dependency_scope == "endpoint": + assert isinstance(sub_dependant, EndpointDependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant) elif depends.dependency_scope == "lifespan": + assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) else: assert_never(depends.dependency_scope) diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py deleted file mode 100644 index 40d82f0e0..000000000 --- a/tests/test_lifespan_scoped_dependencies.py +++ /dev/null @@ -1,703 +0,0 @@ -from enum import StrEnum, auto -from typing import Any, AsyncGenerator, List, Tuple, TypeVar - -import pytest -from fastapi import ( - APIRouter, - BackgroundTasks, - Body, - Cookie, - Depends, - FastAPI, - File, - Form, - Header, - Path, - Query, -) -from fastapi.exceptions import FastAPIError -from fastapi.params import Security -from fastapi.security import SecurityScopes -from starlette.testclient import TestClient -from typing_extensions import Annotated, Generator, Literal, assert_never - -T = TypeVar('T') - - -class DependencyStyle(StrEnum): - SYNC_FUNCTION = auto() - ASYNC_FUNCTION = auto() - SYNC_GENERATOR = auto() - ASYNC_GENERATOR = auto() - - -class DependencyFactory: - def __init__( - self, - dependency_style: DependencyStyle, *, - should_error: bool = False - ): - self.activation_times = 0 - self.deactivation_times = 0 - self.dependency_style = dependency_style - self._should_error = should_error - - def get_dependency(self): - if self.dependency_style == DependencyStyle.SYNC_FUNCTION: - return self._synchronous_function_dependency - - if self.dependency_style == DependencyStyle.SYNC_GENERATOR: - return self._synchronous_generator_dependency - - if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: - return self._asynchronous_function_dependency - - if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: - return self._asynchronous_generator_dependency - - assert_never(self.dependency_style) - - async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - yield self.activation_times - self.deactivation_times += 1 - - def _synchronous_generator_dependency(self) -> Generator[T, None, None]: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - yield self.activation_times - self.deactivation_times += 1 - - async def _asynchronous_function_dependency(self) -> T: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - return self.activation_times - - def _synchronous_function_dependency(self) -> T: - self.activation_times += 1 - if self._should_error: - raise ValueError(self.activation_times) - - return self.activation_times - - -def _expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int -) -> None: - assert dependency_factory.activation_times == 0 - assert dependency_factory.deactivation_times == 0 - with TestClient(app) as client: - assert dependency_factory.activation_times == expected_activation_times - assert dependency_factory.deactivation_times == 0 - - for url, expected_response in urls_and_responses: - response = client.post(url) - response.raise_for_status() - assert response.json() == expected_response - - assert dependency_factory.activation_times == expected_activation_times - assert dependency_factory.deactivation_times == 0 - - assert dependency_factory.activation_times == expected_activation_times - if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION - ): - assert dependency_factory.deactivation_times == expected_activation_times - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - @router.post("/test") - async def endpoint( - dependency: Annotated[None, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - )] - ) -> None: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 - ) - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - if routing_style == "app": - app = FastAPI(dependencies=[depends]) - - @app.post("/test") - async def endpoint() -> None: - return None - else: - app = FastAPI() - router = APIRouter(dependencies=[depends]) - - @router.post("/test") - async def endpoint() -> None: - return None - - app.include_router(router) - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - urls_and_responses=[("/test", None)] * 2, - expected_activation_times=1 - ) - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) -def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"] -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], - ) -> List[int]: - return [sub_dependency1, sub_dependency2] - - @router.post("/test") - async def endpoint( - dependency: Annotated[List[int], Depends( - dependency, - use_cache=use_cache, - dependency_scope=main_dependency_scope, - )] - ) -> List[int]: - return dependency - - if routing_style == "router": - app.include_router(router) - - if use_cache: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test", [1, 1]), - ("/test", [1, 1]), - ], - dependency_factory=dependency_factory, - expected_activation_times=1 - ) - else: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test", [1, 2]), - ("/test", [1, 2]), - ], - dependency_factory=dependency_factory, - expected_activation_times=2 - ) - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: - return dependency3 - - @router.post("/test1") - async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] - ) -> List[int]: - return [dependency1, dependency2, dependency3] - - if routing_style == "router": - app.include_router(router) - - if use_cache: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 1, 1]), - ("/test1", [1, 1, 1]), - ], - dependency_factory=dependency_factory, - expected_activation_times=1 - ) - else: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 2, 3]), - ("/test1", [1, 2, 3]), - ], - dependency_factory=dependency_factory, - expected_activation_times=3 - ) - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: - return dependency3 - - @router.post("/test1") - async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] - ) -> List[int]: - return [dependency1, dependency2, dependency3] - - @router.post("/test2") - async def endpoint2( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] - ) -> List[int]: - return [dependency1, dependency2, dependency3] - - if routing_style == "router": - app.include_router(router) - - if use_cache: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 1, 1]), - ("/test2", [1, 1, 1]), - ("/test1", [1, 1, 1]), - ("/test2", [1, 1, 1]), - ], - dependency_factory=dependency_factory, - expected_activation_times=1 - ) - else: - _expect_correct_amount_of_dependency_activations( - app=app, - urls_and_responses=[ - ("/test1", [1, 2, 3]), - ("/test2", [4, 5, 3]), - ("/test1", [1, 2, 3]), - ("/test2", [4, 5, 3]), - ], - dependency_factory=dependency_factory, - expected_activation_times=5 - ) - -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app", "router"]) -def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, -): - dependency_factory= DependencyFactory(dependency_style) - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=False - ) - - app = FastAPI() - - if routing_style == "app": - router = app - - else: - router = APIRouter() - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends], - ) -> int: - return dependency - - if routing_style == "router": - app.include_router(router) - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 - ) - - -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) -def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation -): - async def dependency_func(param: annotation) -> None: - yield - - app = FastAPI() - - with pytest.raises(FastAPIError): - @app.post("/test") - async def endpoint( - dependency: Annotated[ - None, Depends(dependency_func, dependency_scope="lifespan")] - ) -> None: - return - - -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( - dependency_style: DependencyStyle -): - dependency_factory = DependencyFactory(dependency_style) - - async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] - ) -> AsyncGenerator[int, None]: - yield param - - app = FastAPI() - - @app.post("/test") - async def endpoint( - dependency: Annotated[int, Depends( - lifespan_scoped_dependency, - dependency_scope="lifespan" - )] - ) -> int: - return dependency - - _expect_correct_amount_of_dependency_activations( - app=app, - dependency_factory=dependency_factory, - expected_activation_times=1, - urls_and_responses=[("/test", 1)] * 2 - ) - - -@pytest.mark.parametrize("depends_class", [Depends, Security]) -@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ - "websocket", "endpoint" -]) -def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - route_type -): - async def sub_dependency() -> None: - pass - - async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: - yield - - app = FastAPI() - route_decorator = route_type(app, "/test") - - with pytest.raises(FastAPIError): - @route_decorator - async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] - ) -> None: - return - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_dependencies_must_provide_correct_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - with pytest.raises(FastAPIError): - @router.post("/test") - async def endpoint( - dependency: Annotated[None, Depends( - dependency_factory.get_dependency(), - dependency_scope="incorrect", - use_cache=use_cache, - )] - ) -> None: - assert dependency == 1 - return dependency - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoints_report_incorrect_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - # We intentionally change the dependency scope here to bypass the - # validation at the function level. - depends.dependency_scope = "asdad" - - with pytest.raises(FastAPIError): - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoints_report_uninitialized_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - with TestClient(app) as client: - dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} - - try: - with pytest.raises(FastAPIError): - client.post("/test") - finally: - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoints_report_uninitialized_internal_lifespan( - dependency_style: DependencyStyle, - routing_style, - use_cache -): - dependency_factory= DependencyFactory(dependency_style) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - else: - router = APIRouter() - - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - with TestClient(app) as client: - internal_state = client.app_state["__fastapi__"] - del client.app_state["__fastapi__"] - - try: - with pytest.raises(FastAPIError): - client.post("/test") - finally: - client.app_state["__fastapi__"] = internal_state - - -@pytest.mark.parametrize("use_cache", [True, False]) -@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) -@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): - dependency_factory= DependencyFactory(dependency_style, should_error=True) - depends = Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan", - use_cache=use_cache, - ) - - app = FastAPI() - - if routing_style == "app_endpoint": - router = app - - else: - router = APIRouter() - - @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: - assert dependency == 1 - return dependency - - if routing_style == "router_endpoint": - app.include_router(router) - - with pytest.raises(ValueError) as exception_info: - with TestClient(app): - pass - - assert exception_info.value.args == (1,) - - -# TODO: Add tests for dependency_overrides -# TODO: Add a websocket equivalent to all tests diff --git a/tests/test_lifespan_scoped_dependencies/__init__.py b/tests/test_lifespan_scoped_dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py new file mode 100644 index 000000000..430765aae --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -0,0 +1,634 @@ +from typing import Any, AsyncGenerator, List, Tuple + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, +) +from fastapi.exceptions import DependencyScopeConflict +from fastapi.params import Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated, Literal + +from tests.test_lifespan_scoped_dependencies.testing_utilities import ( + DependencyFactory, + DependencyStyle, + IntentionallyBadDependency, + create_endpoint_0_annotations, + create_endpoint_1_annotation, + create_endpoint_3_annotations, + use_endpoint, + use_websocket, +) + + +def expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + override_dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == 0 + assert override_dependency_factory.deactivation_times == 0 + + with TestClient(app) as client: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + assert override_dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + if is_websocket: + response = use_websocket(client, url) + else: + response = use_endpoint(client, url) + + assert response == expected_response + + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + assert override_dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == 0 + assert override_dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION + ): + assert dependency_factory.deactivation_times == 0 + assert override_dependency_factory.deactivation_times == expected_activation_times + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + ], + expected_value=11 + ) + if routing_style == "router_endpoint": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", 11)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_duplication", [1, 2]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=app, + path="/test", + is_websocket=is_websocket + ) + else: + app = FastAPI() + router = APIRouter(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket + ) + + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 if use_cache else dependency_duplication, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[List[int], Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + )] + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency() + ] = override_dependency_factory.get_dependency() + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [11, 11]), + ("/test", [11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [11, 12]), + ("/test", [11, 12]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency() + ] = override_dependency_factory.get_dependency() + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 11, 11]), + ("/test1", [11, 11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 12, 13]), + ("/test1", [11, 12, 13]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=3, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + create_endpoint_3_annotations( + router=router, + path="/test2", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 11, 11]), + ("/test2", [11, 11, 11]), + ("/test1", [11, 11, 11]), + ("/test2", [11, 11, 11]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [11, 12, 13]), + ("/test2", [14, 15, 13]), + ("/test1", [11, 12, 13]), + ("/test2", [14, 15, 13]), + ], + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=5, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + ) + + if routing_style == "router": + app.include_router(router) + + app.dependency_overrides[ + dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + urls_and_responses=[("/test", 11)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("annotation", [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, +]) +def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation, + is_websocket +): + async def dependency_func() -> None: + yield + + async def override_dependency_func(param: annotation) -> None: + yield + + app = FastAPI() + app.dependency_overrides[dependency_func] = override_dependency_func + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, + Depends(dependency_func, dependency_scope="lifespan") + ] + ) + + with pytest.raises(DependencyScopeConflict): + with TestClient(app): + pass + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( + dependency_style: DependencyStyle, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory( + dependency_style, + value_offset=10 + ) + + async def lifespan_scoped_dependency( + param: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )] + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + int, + Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + ], + ) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + override_dependency_factory=override_dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 11)] * 2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("depends_class", [Depends, Security]) +def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, + is_websocket +): + async def sub_dependency() -> None: + pass + + async def dependency_func() -> None: + yield + + async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + yield + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + ) + + app.dependency_overrides[dependency_func] = override_dependency_func + + with pytest.raises(DependencyScopeConflict): + with TestClient(app): + pass + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_override_lifespan_scoped_dependencies( + use_cache, + dependency_style: DependencyStyle, + routing_style, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, should_error=True) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends] + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + + with pytest.raises(IntentionallyBadDependency) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py new file mode 100644 index 000000000..ccf8d896a --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -0,0 +1,854 @@ +import warnings +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Dict, List, Tuple + +import pytest +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Cookie, + Depends, + FastAPI, + File, + Form, + Header, + Path, + Query, +) +from fastapi.exceptions import ( + DependencyScopeConflict, + InvalidDependencyScope, + UninitializedLifespanDependency, +) +from fastapi.params import Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated, Literal, assert_never + +from tests.test_lifespan_scoped_dependencies.testing_utilities import ( + DependencyFactory, + DependencyStyle, + IntentionallyBadDependency, + create_endpoint_0_annotations, + create_endpoint_1_annotation, + create_endpoint_2_annotations, + create_endpoint_3_annotations, + use_endpoint, + use_websocket, +) + + +def expect_correct_amount_of_dependency_activations( + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool +) -> None: + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + for url, expected_response in urls_and_responses: + if is_websocket: + assert use_websocket(client, url) == expected_response + else: + assert use_endpoint(client, url) == expected_response + + assert dependency_factory.activation_times == expected_activation_times + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == expected_activation_times + if dependency_factory.dependency_style not in ( + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION + ): + assert dependency_factory.deactivation_times == expected_activation_times + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket: bool, +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + )], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_duplication", [1, 2]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_router_dependencies( + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket: bool, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + if routing_style == "app": + app = FastAPI(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=app, + path="/test", + is_websocket=is_websocket + ) + else: + app = FastAPI() + router = APIRouter(dependencies=[depends] * dependency_duplication) + + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket + ) + + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", None)] * 2, + expected_activation_times=1 if use_cache else dependency_duplication, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) +def test_dependency_cache_in_same_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket: bool, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def dependency( + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], + ) -> List[int]: + return [sub_dependency1, sub_dependency2] + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[List[int], Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + )] + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1]), + ("/test", [1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2]), + ("/test", [1, 2]), + ], + dependency_factory=dependency_factory, + expected_activation_times=2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_same_endpoint( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)] + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 1, 1]), + ("/test", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test", [1, 2, 3]), + ("/test", [1, 2, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=3, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_dependency_cache_in_different_endpoints( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: + return dependency3 + + create_endpoint_3_annotations( + router=router, + path="/test1", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)] + ) + + create_endpoint_3_annotations( + router=router, + path="/test2", + is_websocket=is_websocket, + annotation1=Annotated[int, depends], + annotation2=Annotated[int, depends], + annotation3=Annotated[int, Depends(endpoint_dependency)] + ) + + if routing_style == "router": + app.include_router(router) + + if use_cache: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ("/test1", [1, 1, 1]), + ("/test2", [1, 1, 1]), + ], + dependency_factory=dependency_factory, + expected_activation_times=1, + is_websocket=is_websocket + ) + else: + expect_correct_amount_of_dependency_activations( + app=app, + urls_and_responses=[ + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ("/test1", [1, 2, 3]), + ("/test2", [4, 5, 3]), + ], + dependency_factory=dependency_factory, + expected_activation_times=5, + is_websocket=is_websocket + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_no_cached_dependency( + dependency_style: DependencyStyle, + routing_style, + is_websocket, +): + dependency_factory= DependencyFactory(dependency_style) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=False + ) + + app = FastAPI() + + if routing_style == "app": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router": + app.include_router(router) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + urls_and_responses=[("/test", 1)] * 2, + expected_activation_times=1, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("annotation", [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, +]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( + annotation, + is_websocket +): + async def dependency_func(param: annotation) -> None: + yield + + app = FastAPI() + + with pytest.raises(DependencyScopeConflict): + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[ + None, + Depends(dependency_func, dependency_scope="lifespan") + ], + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( + dependency_style: DependencyStyle, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + async def lifespan_scoped_dependency( + param: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )] + ) -> AsyncGenerator[int, None]: + yield param + + app = FastAPI() + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, Depends(lifespan_scoped_dependency)], + expected_value=1 + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2, + is_websocket=is_websocket + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize([ + "dependency_style", + "supports_teardown" +], [ + (DependencyStyle.SYNC_FUNCTION, False), + (DependencyStyle.ASYNC_FUNCTION, False), + (DependencyStyle.SYNC_GENERATOR, True), + (DependencyStyle.ASYNC_GENERATOR, True), +]) +def test_the_same_dependency_can_work_in_different_scopes( + dependency_style: DependencyStyle, + supports_teardown, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + app = FastAPI() + + create_endpoint_2_annotations( + router=app, + path="/test", + is_websocket=is_websocket, + annotation1=Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="endpoint" + )], + annotation2=Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )], + ) + if is_websocket: + get_response = use_websocket + else: + get_response = use_endpoint + + assert dependency_factory.activation_times == 0 + assert dependency_factory.deactivation_times == 0 + with TestClient(app) as client: + assert dependency_factory.activation_times == 1 + assert dependency_factory.deactivation_times == 0 + + assert get_response(client, "/test") == [2, 1] + assert dependency_factory.activation_times == 2 + if supports_teardown: + assert dependency_factory.deactivation_times == 1 + else: + assert dependency_factory.deactivation_times == 0 + + assert get_response(client, "/test") == [3, 1] + assert dependency_factory.activation_times == 3 + if supports_teardown: + assert dependency_factory.deactivation_times == 2 + else: + assert dependency_factory.deactivation_times == 0 + + assert dependency_factory.activation_times == 3 + if supports_teardown: + assert dependency_factory.deactivation_times == 3 + else: + assert dependency_factory.deactivation_times == 0 + + +@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( + dependency_style: DependencyStyle, + is_websocket, + lifespan_style: Literal["lifespan_function", "lifespan_events"] +): + lifespan_started = False + lifespan_ended = False + if lifespan_style == "lifespan_generator": + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: + nonlocal lifespan_started + nonlocal lifespan_ended + lifespan_started = True + yield + lifespan_ended = True + + app = FastAPI(lifespan=lifespan) + elif lifespan_style == "events_decorator": + app = FastAPI() + with warnings.catch_warnings(action="ignore", category=DeprecationWarning): + @app.on_event("startup") + async def startup() -> None: + nonlocal lifespan_started + lifespan_started = True + + @app.on_event("shutdown") + async def shutdown() -> None: + nonlocal lifespan_ended + lifespan_ended = True + elif lifespan_style == "events_constructor": + async def startup() -> None: + nonlocal lifespan_started + lifespan_started = True + + async def shutdown() -> None: + nonlocal lifespan_ended + lifespan_ended = True + app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) + else: + assert_never(lifespan_style) + + dependency_factory = DependencyFactory(dependency_style) + + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )], + expected_value=1 + ) + + expect_correct_amount_of_dependency_activations( + app=app, + dependency_factory=dependency_factory, + expected_activation_times=1, + urls_and_responses=[("/test", 1)] * 2, + is_websocket=is_websocket + ) + assert lifespan_started and lifespan_ended + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("depends_class", [Depends, Security]) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( + depends_class, + is_websocket +): + async def sub_dependency() -> None: + pass + + async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + yield + + app = FastAPI() + + with pytest.raises(DependencyScopeConflict): + create_endpoint_1_annotation( + router=app, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], + ) + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_dependencies_must_provide_correct_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + with pytest.raises( + InvalidDependencyScope, + match=r'Dependency "value" of .* has an invalid scope: ' + r'"incorrect"' + ): + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[None, Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + )] + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_incorrect_dependency_scope( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + with pytest.raises(InvalidDependencyScope): + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends] + ) + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_dependency( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} + + try: + with pytest.raises(UninitializedLifespanDependency): + if is_websocket: + with client.websocket_connect("/test"): + pass + else: + client.post("/test") + finally: + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_endpoints_report_uninitialized_internal_lifespan( + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket +): + dependency_factory = DependencyFactory(dependency_style) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + else: + router = APIRouter() + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with TestClient(app) as client: + internal_state = client.app_state["__fastapi__"] + del client.app_state["__fastapi__"] + + try: + with pytest.raises(UninitializedLifespanDependency): + if is_websocket: + with client.websocket_connect("/test"): + pass + else: + client.post("/test") + finally: + client.app_state["__fastapi__"] = internal_state + + +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) +def test_bad_lifespan_scoped_dependencies( + use_cache, + dependency_style: DependencyStyle, + routing_style, + is_websocket +): + dependency_factory= DependencyFactory(dependency_style, should_error=True) + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan", + use_cache=use_cache, + ) + + app = FastAPI() + + if routing_style == "app_endpoint": + router = app + + else: + router = APIRouter() + + create_endpoint_1_annotation( + router=router, + path="/test", + is_websocket=is_websocket, + annotation=Annotated[int, depends], + expected_value=1 + ) + + if routing_style == "router_endpoint": + app.include_router(router) + + with pytest.raises(IntentionallyBadDependency) as exception_info: + with TestClient(app): + pass + + assert exception_info.value.args == (1,) diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py new file mode 100644 index 000000000..e733205f5 --- /dev/null +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -0,0 +1,202 @@ +from enum import StrEnum, auto +from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never + +from fastapi import APIRouter, FastAPI, WebSocket +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect + +T = TypeVar('T') + + +class DependencyStyle(StrEnum): + SYNC_FUNCTION = auto() + ASYNC_FUNCTION = auto() + SYNC_GENERATOR = auto() + ASYNC_GENERATOR = auto() + + +class IntentionallyBadDependency(Exception): + pass + + +class DependencyFactory: + def __init__( + self, + dependency_style: DependencyStyle, *, + should_error: bool = False, + value_offset: int = 0, + ): + self.activation_times = 0 + self.deactivation_times = 0 + self.dependency_style = dependency_style + self._should_error = should_error + self._value_offset = value_offset + + def get_dependency(self): + if self.dependency_style == DependencyStyle.SYNC_FUNCTION: + return self._synchronous_function_dependency + + if self.dependency_style == DependencyStyle.SYNC_GENERATOR: + return self._synchronous_generator_dependency + + if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: + return self._asynchronous_function_dependency + + if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: + return self._asynchronous_generator_dependency + + assert_never(self.dependency_style) + + async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + yield self.activation_times + self._value_offset + self.deactivation_times += 1 + + def _synchronous_generator_dependency(self) -> Generator[T, None, None]: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + yield self.activation_times + self._value_offset + self.deactivation_times += 1 + + async def _asynchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + return self.activation_times + self._value_offset + + def _synchronous_function_dependency(self) -> T: + self.activation_times += 1 + if self._should_error: + raise IntentionallyBadDependency(self.activation_times) + + return self.activation_times + self._value_offset + + +def use_endpoint(client: TestClient, url: str) -> Any: + response = client.post(url) + response.raise_for_status() + return response.json() + + +def use_websocket(client: TestClient, url: str) -> Any: + with client.websocket_connect(url) as connection: + return connection.receive_json() + + +def create_endpoint_0_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.send_json(None) + except WebSocketDisconnect: + pass + else: + @router.post(path) + async def endpoint() -> None: + return None + + +def create_endpoint_1_annotation( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation: Any, + expected_value: Any = None +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value: annotation + ) -> None: + if expected_value is not None: + assert value == expected_value + + await websocket.accept() + try: + await websocket.send_json(value) + except WebSocketDisconnect: + pass + else: + @router.post(path) + async def endpoint( + value: annotation + ) -> None: + if expected_value is not None: + assert value == expected_value + + return value + +def create_endpoint_2_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + ) -> None: + await websocket.accept() + try: + await websocket.send_json([value1, value2]) + except WebSocketDisconnect: + await websocket.close() + else: + @router.post(path) + async def endpoint( + value1: annotation1, + value2: annotation2, + ) -> list[Any]: + return [value1, value2] + + +def create_endpoint_3_annotations( + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, + annotation3: Any +) -> None: + if is_websocket: + @router.websocket(path) + async def endpoint( + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + value3: annotation3 + ) -> None: + await websocket.accept() + try: + await websocket.send_json([value1, value2, value3]) + except WebSocketDisconnect: + await websocket.close() + else: + @router.post(path) + async def endpoint( + value1: annotation1, + value2: annotation2, + value3: annotation3 + ) -> list[Any]: + return [value1, value2, value3]