From 54ecfb87d87cdc7c087c1381f2f543be7e6ed039 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:12:24 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/applications.py | 11 +- fastapi/dependencies/models.py | 14 +- fastapi/dependencies/utils.py | 59 ++--- fastapi/lifespan.py | 6 +- fastapi/param_functions.py | 8 +- fastapi/params.py | 8 +- fastapi/routing.py | 16 +- tests/test_lifespan_scoped_dependencies.py | 261 +++++++++++---------- tests/test_params_repr.py | 36 ++- tests/test_router_events.py | 14 +- 10 files changed, 214 insertions(+), 219 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 625690a74..6b2b90336 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -938,10 +938,7 @@ class FastAPI(Starlette): if lifespan is None: lifespan = FastAPI._internal_lifespan else: - lifespan = merge_lifespan_context( - FastAPI._internal_lifespan, - lifespan - ) + lifespan = merge_lifespan_context(FastAPI._internal_lifespan, lifespan) # Since we always use a lifespan, starlette will no longer run event # handlers which are defined in the scope of the application. @@ -985,8 +982,7 @@ class FastAPI(Starlette): async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]: async with AsyncExitStack() as exit_stack: lifespan_scoped_dependencies = await resolve_lifespan_dependants( - app=self, - async_exit_stack=exit_stack + app=self, async_exit_stack=exit_stack ) try: for handler in self._on_startup: @@ -1006,7 +1002,6 @@ class FastAPI(Starlette): else: await run_in_threadpool(handler) - def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. This is called by FastAPI @@ -4536,12 +4531,14 @@ class FastAPI(Starlette): Read more about it in the [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated). """ + def decorator(func: DecoratedCallable) -> DecoratedCallable: if event_type == "startup": self._on_startup.append(func) else: self._on_shutdown.append(func) return func + return decorator def middleware( diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 471f9c402..df72e7f5f 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -12,7 +12,10 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None -LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]] +LifespanDependantCacheKey: TypeAlias = Union[ + Tuple[Callable[..., Any], str], Callable[..., Any] +] + @dataclass class LifespanDependant: @@ -30,7 +33,10 @@ class LifespanDependant: self.cache_key = (self.caller, self.name) -EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] +EndpointDependantCacheKey: TypeAlias = Tuple[ + Optional[Callable[..., Any]], Tuple[str, ...] +] + @dataclass class EndpointDependant: @@ -39,8 +45,7 @@ class EndpointDependant: name: Optional[str] = None call: Optional[Callable[..., Any]] = None use_cache: bool = True - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( - init=False) + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) @@ -64,6 +69,7 @@ class EndpointDependant: def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: return tuple(self.endpoint_dependencies + self.lifespan_dependencies) + # Kept for backwards compatibility Dependant = EndpointDependant CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0e9cfe244..527224584 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -134,19 +134,13 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant( - *, - depends: params.Depends, - path: str, - caller: Callable[..., Any] + *, depends: params.Depends, path: str, caller: Callable[..., Any] ) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" return get_sub_dependant( - depends=depends, - dependency=depends.dependency, - path=path, - caller=caller + depends=depends, dependency=depends.dependency, path=path, caller=caller ) @@ -164,7 +158,7 @@ def get_sub_dependant( caller=caller, call=depends.dependency, name=name, - use_cache=depends.use_cache + use_cache=depends.use_cache, ) elif depends.dependency_scope == "endpoint": security_requirement = None @@ -215,7 +209,7 @@ def get_flat_dependant( security_requirements=dependant.security_requirements.copy(), lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, - path=dependant.path + path=dependant.path, ) for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: @@ -296,10 +290,7 @@ def get_lifespan_dependant( dependency_signature = get_typed_signature(call) signature_params = dependency_signature.parameters dependant = LifespanDependant( - call=call, - name=name, - use_cache=use_cache, - caller=caller + call=call, name=name, use_cache=use_cache, caller=caller ) for param_name, param in signature_params.items(): param_details = analyze_param( @@ -312,26 +303,25 @@ def get_lifespan_dependant( raise FastAPIError( f"Lifespan dependency {dependant.name} was defined with an " f"invalid argument {param_name}. Lifespan dependencies may " - f"only use other lifespan dependencies as arguments.") + f"only use other lifespan dependencies as arguments." + ) if param_details.depends.dependency_scope != "lifespan": raise FastAPIError( - "Lifespan dependency may not use " - "sub-dependencies of other scopes." + "Lifespan dependency may not use " "sub-dependencies of other scopes." ) sub_dependant = get_lifespan_dependant( name=param_name, call=param_details.depends.dependency, use_cache=param_details.depends.use_cache, - caller=call + caller=call, ) dependant.dependencies.append(sub_dependant) return dependant - def get_endpoint_dependant( *, path: str, @@ -378,8 +368,8 @@ def get_endpoint_dependant( dependant.lifespan_dependencies.append(sub_dependant) else: raise FastAPIError( - f"Dependency \"{param_name}\" of `{call}` has an invalid " - f"sub-dependency scope: \"{param_details.depends.dependency_scope}\"" + f'Dependency "{param_name}" of `{call}` has an invalid ' + f'sub-dependency scope: "{param_details.depends.dependency_scope}"' ) continue if add_non_field_param_to_dependency( @@ -659,7 +649,9 @@ async def solve_lifespan_dependant( *, dependant: LifespanDependant, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[LifespanDependantCacheKey, Callable[..., Any]]] = None, + dependency_cache: Optional[ + Dict[LifespanDependantCacheKey, Callable[..., Any]] + ] = None, async_exit_stack: AsyncExitStack, ) -> SolvedLifespanDependant: dependency_cache = dependency_cache or {} @@ -673,9 +665,7 @@ async def solve_lifespan_dependant( sub_dependant: LifespanDependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Callable[..., Any], sub_dependant.cache_key - ) + sub_dependant.cache_key = cast(Callable[..., Any], sub_dependant.cache_key) assert sub_dependant.name, ( "Lifespan scoped dependencies should not be able to have " "subdependencies with no name" @@ -691,9 +681,7 @@ async def solve_lifespan_dependant( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) sub_dependant_to_solve = get_lifespan_dependant( - call=call, - name=sub_dependant.name, - caller=dependant.call + call=call, name=sub_dependant.name, caller=dependant.call ) solved_sub_dependant = await solve_lifespan_dependant( @@ -707,9 +695,7 @@ async def solve_lifespan_dependant( if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call): value = await solve_generator( - call=dependant.call, - stack=async_exit_stack, - sub_values=dependency_arguments + call=dependant.call, stack=async_exit_stack, sub_values=dependency_arguments ) elif is_coroutine_callable(dependant.call): value = await dependant.call(**dependency_arguments) @@ -754,18 +740,17 @@ async def solve_dependencies( continue try: lifespan_scoped_dependencies = request.state.__fastapi__[ - "lifespan_scoped_dependencies"] + "lifespan_scoped_dependencies" + ] except AttributeError as e: - raise FastAPIError( - "FastAPI's internal lifespan was not initialized" - ) from e + raise FastAPIError("FastAPI's internal lifespan was not initialized") from e try: value = lifespan_scoped_dependencies[sub_dependant.cache_key] except KeyError as e: raise FastAPIError( - f"Dependency {sub_dependant.name} of {dependant.call} " - f"was not initialized." + f"Dependency {sub_dependant.name} of {dependant.call} " + f"was not initialized." ) from e values[sub_dependant.name] = value diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py index 184c943d7..bff285267 100644 --- a/fastapi/lifespan.py +++ b/fastapi/lifespan.py @@ -27,9 +27,7 @@ def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: async def resolve_lifespan_dependants( - *, - app: FastAPI, - async_exit_stack: AsyncExitStack + *, app: FastAPI, async_exit_stack: AsyncExitStack ) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: lifespan_dependants = _get_lifespan_dependants(app) dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} @@ -38,7 +36,7 @@ async def resolve_lifespan_dependants( dependant=lifespan_dependant, dependency_overrides_provider=app, dependency_cache=dependency_cache, - async_exit_stack=async_exit_stack + async_exit_stack=async_exit_stack, ) dependency_cache.update(solved_dependency.dependency_cache) diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index e9a15c166..4112e90e2 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -2272,8 +2272,8 @@ def Depends( # noqa: N802 or any other annotation which does not make sense in a scope of an application's entire lifespan. """ - ) - ] = "endpoint" + ), + ] = "endpoint", ) -> Any: """ Declare a FastAPI dependency. @@ -2304,7 +2304,9 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope) + return params.Depends( + dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope + ) def Security( # noqa: N802 diff --git a/fastapi/params.py b/fastapi/params.py index f655acba8..954d4b27d 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -764,7 +764,7 @@ class Depends: dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True, - dependency_scope: DependencyScope = "endpoint" + dependency_scope: DependencyScope = "endpoint", ): self.dependency = dependency self.use_cache = use_cache @@ -776,7 +776,7 @@ class Depends: if self.dependency_scope == "endpoint": dependency_scope = "" else: - dependency_scope = f", dependency_scope=\"{self.dependency_scope}\"" + dependency_scope = f', dependency_scope="{self.dependency_scope}"' return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})" @@ -790,8 +790,6 @@ class Security(Depends): use_cache: bool = True, ): super().__init__( - dependency=dependency, - use_cache=use_cache, - dependency_scope="endpoint" + dependency=dependency, use_cache=use_cache, dependency_scope="endpoint" ) self.scopes = scopes or [] diff --git a/fastapi/routing.py b/fastapi/routing.py index 50b51a352..24ea39268 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -400,12 +400,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for depends in self.dependencies[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self + depends=depends, path=self.path_format, caller=self ) if depends.dependency_scope == "endpoint": self.dependant.endpoint_dependencies.insert(0, sub_dependant) @@ -563,12 +563,12 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for depends in self.dependencies[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self.__call__ + depends=depends, path=self.path_format, caller=self.__call__ ) if depends.dependency_scope == "endpoint": self.dependant.endpoint_dependencies.insert(0, sub_dependant) diff --git a/tests/test_lifespan_scoped_dependencies.py b/tests/test_lifespan_scoped_dependencies.py index 40d82f0e0..70cbc5bb3 100644 --- a/tests/test_lifespan_scoped_dependencies.py +++ b/tests/test_lifespan_scoped_dependencies.py @@ -21,7 +21,7 @@ from fastapi.security import SecurityScopes from starlette.testclient import TestClient from typing_extensions import Annotated, Generator, Literal, assert_never -T = TypeVar('T') +T = TypeVar("T") class DependencyStyle(StrEnum): @@ -33,9 +33,7 @@ class DependencyStyle(StrEnum): class DependencyFactory: def __init__( - self, - dependency_style: DependencyStyle, *, - should_error: bool = False + self, dependency_style: DependencyStyle, *, should_error: bool = False ): self.activation_times = 0 self.deactivation_times = 0 @@ -89,11 +87,11 @@ class DependencyFactory: def _expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, ) -> None: assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 @@ -111,16 +109,19 @@ def _expect_correct_amount_of_dependency_activations( assert dependency_factory.activation_times == expected_activation_times if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, ): assert dependency_factory.deactivation_times == expected_activation_times + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): - dependency_factory= DependencyFactory(dependency_style) +def test_endpoint_dependencies( + dependency_style: DependencyStyle, routing_style, use_cache +): + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -131,11 +132,14 @@ def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, @router.post("/test") async def endpoint( - dependency: Annotated[None, Depends( + dependency: Annotated[ + None, + Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", use_cache=use_cache, - )] + ), + ], ) -> None: assert dependency == 1 return dependency @@ -147,23 +151,22 @@ def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, app=app, dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 + expected_activation_times=1, ) + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) if routing_style == "app": @@ -186,7 +189,7 @@ def test_router_dependencies( app=app, dependency_factory=dependency_factory, urls_and_responses=[("/test", None)] * 2, - expected_activation_times=1 + expected_activation_times=1, ) @@ -195,17 +198,17 @@ def test_router_dependencies( @pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"] + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -217,18 +220,21 @@ def test_dependency_cache_in_same_dependency( router = APIRouter() async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], ) -> List[int]: return [sub_dependency1, sub_dependency2] @router.post("/test") async def endpoint( - dependency: Annotated[List[int], Depends( + dependency: Annotated[ + List[int], + Depends( dependency, use_cache=use_cache, dependency_scope=main_dependency_scope, - )] + ), + ], ) -> List[int]: return dependency @@ -243,7 +249,7 @@ def test_dependency_cache_in_same_dependency( ("/test", [1, 1]), ], dependency_factory=dependency_factory, - expected_activation_times=1 + expected_activation_times=1, ) else: _expect_correct_amount_of_dependency_activations( @@ -253,7 +259,7 @@ def test_dependency_cache_in_same_dependency( ("/test", [1, 2]), ], dependency_factory=dependency_factory, - expected_activation_times=2 + expected_activation_times=2, ) @@ -261,16 +267,14 @@ def test_dependency_cache_in_same_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -286,9 +290,9 @@ def test_dependency_cache_in_same_endpoint( @router.post("/test1") async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)], ) -> List[int]: return [dependency1, dependency2, dependency3] @@ -303,7 +307,7 @@ def test_dependency_cache_in_same_endpoint( ("/test1", [1, 1, 1]), ], dependency_factory=dependency_factory, - expected_activation_times=1 + expected_activation_times=1, ) else: _expect_correct_amount_of_dependency_activations( @@ -313,23 +317,22 @@ def test_dependency_cache_in_same_endpoint( ("/test1", [1, 2, 3]), ], dependency_factory=dependency_factory, - expected_activation_times=3 + expected_activation_times=3, ) + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -345,17 +348,17 @@ def test_dependency_cache_in_different_endpoints( @router.post("/test1") async def endpoint( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)], ) -> List[int]: return [dependency1, dependency2, dependency3] @router.post("/test2") async def endpoint2( - dependency1: Annotated[int, depends], - dependency2: Annotated[int, depends], - dependency3: Annotated[int, Depends(endpoint_dependency)] + dependency1: Annotated[int, depends], + dependency2: Annotated[int, depends], + dependency3: Annotated[int, Depends(endpoint_dependency)], ) -> List[int]: return [dependency1, dependency2, dependency3] @@ -372,7 +375,7 @@ def test_dependency_cache_in_different_endpoints( ("/test2", [1, 1, 1]), ], dependency_factory=dependency_factory, - expected_activation_times=1 + expected_activation_times=1, ) else: _expect_correct_amount_of_dependency_activations( @@ -384,21 +387,22 @@ def test_dependency_cache_in_different_endpoints( ("/test2", [4, 5, 3]), ], dependency_factory=dependency_factory, - expected_activation_times=5 + expected_activation_times=5, ) + @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, + dependency_style: DependencyStyle, + routing_style, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=False + use_cache=False, ) app = FastAPI() @@ -411,7 +415,7 @@ def test_no_cached_dependency( @router.post("/test") async def endpoint( - dependency: Annotated[int, depends], + dependency: Annotated[int, depends], ) -> int: return dependency @@ -422,49 +426,52 @@ def test_no_cached_dependency( app=app, dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, - expected_activation_times=1 + expected_activation_times=1, ) -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) -def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation -): +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + ], +) +def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(annotation): async def dependency_func(param: annotation) -> None: yield app = FastAPI() with pytest.raises(FastAPIError): + @app.post("/test") async def endpoint( - dependency: Annotated[ - None, Depends(dependency_func, dependency_scope="lifespan")] + dependency: Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) -> None: return @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( - dependency_style: DependencyStyle + dependency_style: DependencyStyle, ): dependency_factory = DependencyFactory(dependency_style) async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) -> AsyncGenerator[int, None]: yield param @@ -472,10 +479,9 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( @app.post("/test") async def endpoint( - dependency: Annotated[int, Depends( - lifespan_scoped_dependency, - dependency_scope="lifespan" - )] + dependency: Annotated[ + int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + ], ) -> int: return dependency @@ -483,42 +489,44 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( app=app, dependency_factory=dependency_factory, expected_activation_times=1, - urls_and_responses=[("/test", 1)] * 2 + urls_and_responses=[("/test", 1)] * 2, ) @pytest.mark.parametrize("depends_class", [Depends, Security]) -@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ - "websocket", "endpoint" -]) +@pytest.mark.parametrize( + "route_type", [FastAPI.post, FastAPI.websocket], ids=["websocket", "endpoint"] +) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - route_type + depends_class, route_type ): async def sub_dependency() -> None: pass - async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + async def dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: yield app = FastAPI() route_decorator = route_type(app, "/test") with pytest.raises(FastAPIError): + @route_decorator - async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + async def endpoint( + x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], ) -> None: return + @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_dependencies_must_provide_correct_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -528,13 +536,17 @@ def test_dependencies_must_provide_correct_dependency_scope( router = APIRouter() with pytest.raises(FastAPIError): + @router.post("/test") async def endpoint( - dependency: Annotated[None, Depends( + dependency: Annotated[ + None, + Depends( dependency_factory.get_dependency(), dependency_scope="incorrect", use_cache=use_cache, - )] + ), + ], ) -> None: assert dependency == 1 return dependency @@ -544,11 +556,9 @@ def test_dependencies_must_provide_correct_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_incorrect_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -567,10 +577,9 @@ def test_endpoints_report_incorrect_dependency_scope( depends.dependency_scope = "asdad" with pytest.raises(FastAPIError): + @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency @@ -579,11 +588,9 @@ def test_endpoints_report_incorrect_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -599,9 +606,7 @@ def test_endpoints_report_uninitialized_dependency( ) @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency @@ -616,18 +621,18 @@ def test_endpoints_report_uninitialized_dependency( with pytest.raises(FastAPIError): client.post("/test") finally: - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( + dependencies + ) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_internal_lifespan( - dependency_style: DependencyStyle, - routing_style, - use_cache + dependency_style: DependencyStyle, routing_style, use_cache ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -643,9 +648,7 @@ def test_endpoints_report_uninitialized_internal_lifespan( ) @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency @@ -666,8 +669,10 @@ def test_endpoints_report_uninitialized_internal_lifespan( @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) -def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): - dependency_factory= DependencyFactory(dependency_style, should_error=True) +def test_bad_lifespan_scoped_dependencies( + use_cache, dependency_style: DependencyStyle, routing_style +): + dependency_factory = DependencyFactory(dependency_style, should_error=True) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", @@ -683,9 +688,7 @@ def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: Dependenc router = APIRouter() @router.post("/test") - async def endpoint( - dependency: Annotated[int, depends] - ) -> int: + async def endpoint(dependency: Annotated[int, depends]) -> int: assert dependency == 1 return dependency diff --git a/tests/test_params_repr.py b/tests/test_params_repr.py index 10f044888..8921026b2 100644 --- a/tests/test_params_repr.py +++ b/tests/test_params_repr.py @@ -144,16 +144,30 @@ def test_body_repr_list(): assert repr(Body([])) == "Body([])" -@pytest.mark.parametrize(["depends", "expected_repr"], [ - [Depends(), "Depends(NoneType)"], - [Depends(get_user), "Depends(get_user)"], - [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], - [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], - - [Depends(dependency_scope="lifespan"), "Depends(NoneType, dependency_scope=\"lifespan\")"], - [Depends(get_user, dependency_scope="lifespan"), "Depends(get_user, dependency_scope=\"lifespan\")"], - [Depends(use_cache=False, dependency_scope="lifespan"), "Depends(NoneType, use_cache=False, dependency_scope=\"lifespan\")"], - [Depends(get_user, use_cache=False, dependency_scope="lifespan"), "Depends(get_user, use_cache=False, dependency_scope=\"lifespan\")"], -]) +@pytest.mark.parametrize( + ["depends", "expected_repr"], + [ + [Depends(), "Depends(NoneType)"], + [Depends(get_user), "Depends(get_user)"], + [Depends(use_cache=False), "Depends(NoneType, use_cache=False)"], + [Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"], + [ + Depends(dependency_scope="lifespan"), + 'Depends(NoneType, dependency_scope="lifespan")', + ], + [ + Depends(get_user, dependency_scope="lifespan"), + 'Depends(get_user, dependency_scope="lifespan")', + ], + [ + Depends(use_cache=False, dependency_scope="lifespan"), + 'Depends(NoneType, use_cache=False, dependency_scope="lifespan")', + ], + [ + Depends(get_user, use_cache=False, dependency_scope="lifespan"), + 'Depends(get_user, use_cache=False, dependency_scope="lifespan")', + ], + ], +) def test_depends_repr(depends, expected_repr): assert repr(depends) == expected_repr diff --git a/tests/test_router_events.py b/tests/test_router_events.py index 8289a7301..2f110e684 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -199,9 +199,7 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None: "app_specific": True, "router_specific": True, "overridden": "app", - "__fastapi__": { - "lifespan_scoped_dependencies": {} - }, + "__fastapi__": {"lifespan_scoped_dependencies": {}}, } @@ -219,11 +217,7 @@ def test_merged_no_return_lifespans_return_none() -> None: app.include_router(router) with TestClient(app) as client: - assert client.app_state == { - "__fastapi__": { - "lifespan_scoped_dependencies": {} - } - } + assert client.app_state == {"__fastapi__": {"lifespan_scoped_dependencies": {}}} def test_merged_mixed_state_lifespans() -> None: @@ -248,7 +242,5 @@ def test_merged_mixed_state_lifespans() -> None: with TestClient(app) as client: assert client.app_state == { "router": True, - "__fastapi__": { - "lifespan_scoped_dependencies": {} - } + "__fastapi__": {"lifespan_scoped_dependencies": {}}, }