diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 2ea813d48..9bbd81d37 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -12,7 +12,10 @@ class SecurityRequirement: scopes: Optional[Sequence[str]] = None -LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]] +LifespanDependantCacheKey: TypeAlias = Union[ + Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any] +] + @dataclass class LifespanDependant: @@ -30,9 +33,9 @@ class LifespanDependant: elif self.name is not None: self.cache_key = (self.caller, self.name) else: - assert self.index is not None, ( - "Lifespan dependency must have an associated name or index." - ) + assert ( + self.index is not None + ), "Lifespan dependency must have an associated name or index." self.cache_key = (self.caller, self.index) @@ -49,8 +52,7 @@ class EndpointDependant: call: Optional[Callable[..., Any]] = None use_cache: bool = True index: Optional[int] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field( - init=False) + cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) path_params: List[ModelField] = field(default_factory=list) query_params: List[ModelField] = field(default_factory=list) header_params: List[ModelField] = field(default_factory=list) @@ -74,11 +76,11 @@ class EndpointDependant: def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]: lifespan_dependencies = cast( List[Union[EndpointDependant, LifespanDependant]], - self.lifespan_dependencies + self.lifespan_dependencies, ) endpoint_dependencies = cast( List[Union[EndpointDependant, LifespanDependant]], - self.endpoint_dependencies + self.endpoint_dependencies, ) return tuple(lifespan_dependencies + endpoint_dependencies) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96f4a910f..96a00c12a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -147,11 +147,7 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant( - *, - depends: params.Depends, - path: str, - caller: Callable[..., Any], - index: int + *, depends: params.Depends, path: str, caller: Callable[..., Any], index: int ) -> Union[EndpointDependant, LifespanDependant]: assert callable( depends.dependency @@ -161,7 +157,7 @@ def get_parameterless_sub_dependant( dependency=depends.dependency, path=path, caller=caller, - index=index + index=index, ) @@ -181,7 +177,7 @@ def get_sub_dependant( call=dependency, name=name, use_cache=depends.use_cache, - index=index + index=index, ) elif depends.dependency_scope == "endpoint": security_requirement = None @@ -202,15 +198,15 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, - index=index + index=index, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) return sub_dependant else: raise InvalidDependencyScope( - f"Dependency \"{name}\" of {caller} has an invalid " - f"scope: \"{depends.dependency_scope}\"" + f'Dependency "{name}" of {caller} has an invalid ' + f'scope: "{depends.dependency_scope}"' ) @@ -233,7 +229,7 @@ def get_flat_dependant( security_requirements=dependant.security_requirements.copy(), lifespan_dependencies=dependant.lifespan_dependencies.copy(), use_cache=dependant.use_cache, - path=dependant.path + path=dependant.path, ) for sub_dependant in dependant.endpoint_dependencies: if skip_repeats and sub_dependant.cache_key in visited: @@ -310,16 +306,12 @@ def get_lifespan_dependant( call: Callable[..., Any], name: Optional[str] = None, use_cache: bool = True, - index: Optional[int] = None + index: Optional[int] = None, ) -> LifespanDependant: dependency_signature = get_typed_signature(call) signature_params = dependency_signature.parameters dependant = LifespanDependant( - call=call, - name=name, - use_cache=use_cache, - caller=caller, - index=index + call=call, name=name, use_cache=use_cache, caller=caller, index=index ) for param_name, param in signature_params.items(): param_details = analyze_param( @@ -330,16 +322,17 @@ def get_lifespan_dependant( ) if param_details.depends is None: raise DependencyScopeConflict( - f"Lifespan scoped dependency \"{dependant.name}\" was defined " - f"with an invalid argument: \"{param_name}\" which is " - f"\"endpoint\" scoped. Lifespan scoped dependencies may only " - f"use lifespan scoped sub-dependencies.") + f'Lifespan scoped dependency "{dependant.name}" was defined ' + f'with an invalid argument: "{param_name}" which is ' + f'"endpoint" scoped. Lifespan scoped dependencies may only ' + f"use lifespan scoped sub-dependencies." + ) if param_details.depends.dependency_scope != "lifespan": raise DependencyScopeConflict( f"Lifespan scoped dependency {dependant.name} was defined with the " - f"sub-dependency \"{param_name}\" which is " - f"\"{param_details.depends.dependency_scope}\" scoped. " + f'sub-dependency "{param_name}" which is ' + f'"{param_details.depends.dependency_scope}" scoped. ' f"Lifespan scoped dependencies may only use lifespan scoped " f"sub-dependencies." ) @@ -350,7 +343,7 @@ def get_lifespan_dependant( name=param_name, call=param_details.depends.dependency, use_cache=param_details.depends.use_cache, - caller=call + caller=call, ) dependant.dependencies.append(sub_dependant) @@ -364,7 +357,7 @@ def get_endpoint_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, - index: Optional[int] = None + index: Optional[int] = None, ) -> EndpointDependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -375,7 +368,7 @@ def get_endpoint_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, - index=index + index=index, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -692,14 +685,12 @@ async def solve_lifespan_dependant( call = dependant.call dependant_to_solve = dependant if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides ): - call = getattr( - dependency_overrides_provider, - "dependency_overrides", - {} - ).get(dependant.call, dependant.call) + call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get( + dependant.call, dependant.call + ) dependant_to_solve = get_lifespan_dependant( caller=dependant.caller, call=call, @@ -725,9 +716,7 @@ async def solve_lifespan_dependant( if is_gen_callable(call) or is_async_gen_callable(call): value = await solve_generator( - call=call, - stack=async_exit_stack, - sub_values=dependency_arguments + call=call, stack=async_exit_stack, sub_values=dependency_arguments ) elif is_coroutine_callable(call): value = await call(**dependency_arguments) @@ -773,7 +762,8 @@ async def solve_dependencies( try: lifespan_scoped_dependencies = request.state.__fastapi__[ - "lifespan_scoped_dependencies"] + "lifespan_scoped_dependencies" + ] except (AttributeError, KeyError) as e: raise UninitializedLifespanDependency( "FastAPI's internal lifespan was not initialized correctly." @@ -783,8 +773,8 @@ async def solve_dependencies( value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key] except KeyError as e: raise UninitializedLifespanDependency( - f"Dependency \"{lifespan_sub_dependant.name}\" of " - f"`{dependant.call}` was not initialized correctly." + f'Dependency "{lifespan_sub_dependant.name}" of ' + f"`{dependant.call}` was not initialized correctly." ) from e values[lifespan_sub_dependant.name] = value diff --git a/fastapi/routing.py b/fastapi/routing.py index e11edaa13..376ad9c8b 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -400,13 +400,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self.__call__, - index=i + depends=depends, path=self.path_format, caller=self.__call__, index=i ) if depends.dependency_scope == "endpoint": assert isinstance(sub_dependant, EndpointDependant) @@ -566,13 +565,12 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_endpoint_dependant( + path=self.path_format, call=self.endpoint + ) for i, depends in list(enumerate(self.dependencies))[::-1]: sub_dependant = get_parameterless_sub_dependant( - depends=depends, - path=self.path_format, - caller=self.__call__, - index=i + depends=depends, path=self.path_format, caller=self.__call__, index=i ) if depends.dependency_scope == "endpoint": assert isinstance(sub_dependant, EndpointDependant) diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py index 430765aae..61d2fe3b2 100644 --- a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -33,13 +33,13 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import ( def expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - override_dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int, - is_websocket: bool + *, + app: FastAPI, + dependency_factory: DependencyFactory, + override_dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool, ) -> None: assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 @@ -62,17 +62,22 @@ def expect_correct_amount_of_dependency_activations( assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 - assert override_dependency_factory.activation_times == expected_activation_times + assert ( + override_dependency_factory.activation_times + == expected_activation_times + ) assert override_dependency_factory.deactivation_times == 0 assert dependency_factory.activation_times == 0 assert override_dependency_factory.activation_times == expected_activation_times if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, ): assert dependency_factory.deactivation_times == 0 - assert override_dependency_factory.deactivation_times == expected_activation_times + assert ( + override_dependency_factory.deactivation_times == expected_activation_times + ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @@ -80,16 +85,10 @@ def expect_correct_amount_of_dependency_activations( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoint_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): dependency_factory = DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) app = FastAPI() @@ -108,14 +107,16 @@ def test_endpoint_dependencies( dependency_factory.get_dependency(), dependency_scope="lifespan", use_cache=use_cache, - ) + ), ], - expected_value=11 + expected_value=11, ) if routing_style == "router_endpoint": app.include_router(router) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -123,7 +124,7 @@ def test_endpoint_dependencies( override_dependency_factory=override_dependency_factory, urls_and_responses=[("/test", 11)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) @@ -133,45 +134,40 @@ def test_endpoint_dependencies( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - dependency_duplication, - is_websocket + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket, ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) if routing_style == "app": app = FastAPI(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=app, - path="/test", - is_websocket=is_websocket + router=app, path="/test", is_websocket=is_websocket ) else: app = FastAPI() router = APIRouter(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=router, - path="/test", - is_websocket=is_websocket + router=router, path="/test", is_websocket=is_websocket ) app.include_router(router) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -179,31 +175,29 @@ def test_router_dependencies( override_dependency_factory=override_dependency_factory, urls_and_responses=[("/test", None)] * 2, expected_activation_times=1 if use_cache else dependency_duplication, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"], - is_websocket + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket, ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -215,8 +209,8 @@ def test_dependency_cache_in_same_dependency( router = APIRouter() async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], ) -> List[int]: return [sub_dependency1, sub_dependency2] @@ -224,19 +218,22 @@ def test_dependency_cache_in_same_dependency( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[List[int], Depends( - dependency, - use_cache=use_cache, - dependency_scope=main_dependency_scope, - )] + annotation=Annotated[ + List[int], + Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + ), + ], ) if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency() - ] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) if use_cache: expect_correct_amount_of_dependency_activations( @@ -248,7 +245,7 @@ def test_dependency_cache_in_same_dependency( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -260,7 +257,7 @@ def test_dependency_cache_in_same_dependency( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @@ -269,21 +266,15 @@ def test_dependency_cache_in_same_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -309,9 +300,9 @@ def test_dependency_cache_in_same_endpoint( if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency() - ] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) if use_cache: expect_correct_amount_of_dependency_activations( @@ -323,7 +314,7 @@ def test_dependency_cache_in_same_endpoint( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -335,29 +326,24 @@ def test_dependency_cache_in_same_endpoint( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=3, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -392,8 +378,9 @@ def test_dependency_cache_in_different_endpoints( if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) if use_cache: expect_correct_amount_of_dependency_activations( @@ -407,7 +394,7 @@ def test_dependency_cache_in_different_endpoints( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -421,27 +408,23 @@ def test_dependency_cache_in_different_endpoints( dependency_factory=dependency_factory, override_dependency_factory=override_dependency_factory, expected_activation_times=5, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, - is_websocket + dependency_style: DependencyStyle, routing_style, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + dependency_factory = DependencyFactory(dependency_style) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=False + use_cache=False, ) app = FastAPI() @@ -462,8 +445,9 @@ def test_no_cached_dependency( if routing_style == "router": app.include_router(router) - app.dependency_overrides[ - dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -471,24 +455,27 @@ def test_no_cached_dependency( override_dependency_factory=override_dependency_factory, urls_and_responses=[("/test", 11)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + ], +) def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation, - is_websocket + annotation, is_websocket ): async def dependency_func() -> None: yield @@ -503,9 +490,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, - Depends(dependency_func, dependency_scope="lifespan") - ] + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) with pytest.raises(DependencyScopeConflict): @@ -516,20 +503,16 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( - dependency_style: DependencyStyle, - is_websocket + dependency_style: DependencyStyle, is_websocket ): dependency_factory = DependencyFactory(dependency_style) - override_dependency_factory = DependencyFactory( - dependency_style, - value_offset=10 - ) + override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) -> AsyncGenerator[int, None]: yield param @@ -540,12 +523,13 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco path="/test", is_websocket=is_websocket, annotation=Annotated[ - int, - Depends(lifespan_scoped_dependency, dependency_scope="lifespan") + int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") ], ) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) expect_correct_amount_of_dependency_activations( app=app, @@ -553,15 +537,14 @@ def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_sco override_dependency_factory=override_dependency_factory, expected_activation_times=1, urls_and_responses=[("/test", 11)] * 2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("depends_class", [Depends, Security]) def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - is_websocket + depends_class, is_websocket ): async def sub_dependency() -> None: pass @@ -569,7 +552,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen async def dependency_func() -> None: yield - async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + async def override_dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: yield app = FastAPI() @@ -578,7 +563,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) app.dependency_overrides[dependency_func] = override_dependency_func @@ -593,12 +580,9 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_bad_override_lifespan_scoped_dependencies( - use_cache, - dependency_style: DependencyStyle, - routing_style, - is_websocket + use_cache, dependency_style: DependencyStyle, routing_style, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) override_dependency_factory = DependencyFactory(dependency_style, should_error=True) depends = Depends( @@ -619,13 +603,15 @@ def test_bad_override_lifespan_scoped_dependencies( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[int, depends] + annotation=Annotated[int, depends], ) if routing_style == "router_endpoint": app.include_router(router) - app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() + app.dependency_overrides[dependency_factory.get_dependency()] = ( + override_dependency_factory.get_dependency() + ) with pytest.raises(IntentionallyBadDependency) as exception_info: with TestClient(app): diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index ccf8d896a..66caf065a 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -40,12 +40,12 @@ from tests.test_lifespan_scoped_dependencies.testing_utilities import ( def expect_correct_amount_of_dependency_activations( - *, - app: FastAPI, - dependency_factory: DependencyFactory, - urls_and_responses: List[Tuple[str, Any]], - expected_activation_times: int, - is_websocket: bool + *, + app: FastAPI, + dependency_factory: DependencyFactory, + urls_and_responses: List[Tuple[str, Any]], + expected_activation_times: int, + is_websocket: bool, ) -> None: assert dependency_factory.activation_times == 0 assert dependency_factory.deactivation_times == 0 @@ -64,21 +64,23 @@ def expect_correct_amount_of_dependency_activations( assert dependency_factory.activation_times == expected_activation_times if dependency_factory.dependency_style not in ( - DependencyStyle.SYNC_FUNCTION, - DependencyStyle.ASYNC_FUNCTION + DependencyStyle.SYNC_FUNCTION, + DependencyStyle.ASYNC_FUNCTION, ): assert dependency_factory.deactivation_times == expected_activation_times @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) +@pytest.mark.parametrize( + "use_cache", [True, False], ids=["With Cache", "Without Cache"] +) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoint_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket: bool, + dependency_style: DependencyStyle, + routing_style, + use_cache, + is_websocket: bool, ): dependency_factory = DependencyFactory(dependency_style) @@ -93,12 +95,15 @@ def test_endpoint_dependencies( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends( + annotation=Annotated[ + None, + Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", use_cache=use_cache, - )], - expected_value=1 + ), + ], + expected_value=1, ) if routing_style == "router_endpoint": @@ -109,45 +114,42 @@ def test_endpoint_dependencies( dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_duplication", [1, 2]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_router_dependencies( - dependency_style: DependencyStyle, - routing_style, - use_cache, - dependency_duplication, - is_websocket: bool, + dependency_style: DependencyStyle, + routing_style, + use_cache, + dependency_duplication, + is_websocket: bool, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) if routing_style == "app": app = FastAPI(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=app, - path="/test", - is_websocket=is_websocket + router=app, path="/test", is_websocket=is_websocket ) else: app = FastAPI() router = APIRouter(dependencies=[depends] * dependency_duplication) create_endpoint_0_annotations( - router=router, - path="/test", - is_websocket=is_websocket + router=router, path="/test", is_websocket=is_websocket ) app.include_router(router) @@ -157,27 +159,28 @@ def test_router_dependencies( dependency_factory=dependency_factory, urls_and_responses=[("/test", None)] * 2, expected_activation_times=1 if use_cache else dependency_duplication, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) @pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) def test_dependency_cache_in_same_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - main_dependency_scope: Literal["endpoint", "lifespan"], - is_websocket: bool, + dependency_style: DependencyStyle, + routing_style, + use_cache, + main_dependency_scope: Literal["endpoint", "lifespan"], + is_websocket: bool, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -189,8 +192,8 @@ def test_dependency_cache_in_same_dependency( router = APIRouter() async def dependency( - sub_dependency1: Annotated[int, depends], - sub_dependency2: Annotated[int, depends], + sub_dependency1: Annotated[int, depends], + sub_dependency2: Annotated[int, depends], ) -> List[int]: return [sub_dependency1, sub_dependency2] @@ -198,11 +201,14 @@ def test_dependency_cache_in_same_dependency( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[List[int], Depends( - dependency, - use_cache=use_cache, - dependency_scope=main_dependency_scope, - )] + annotation=Annotated[ + List[int], + Depends( + dependency, + use_cache=use_cache, + dependency_scope=main_dependency_scope, + ), + ], ) if routing_style == "router": @@ -217,7 +223,7 @@ def test_dependency_cache_in_same_dependency( ], dependency_factory=dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -228,7 +234,7 @@ def test_dependency_cache_in_same_dependency( ], dependency_factory=dependency_factory, expected_activation_times=2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @@ -237,17 +243,14 @@ def test_dependency_cache_in_same_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_same_endpoint( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -267,7 +270,7 @@ def test_dependency_cache_in_same_endpoint( is_websocket=is_websocket, annotation1=Annotated[int, depends], annotation2=Annotated[int, depends], - annotation3=Annotated[int, Depends(endpoint_dependency)] + annotation3=Annotated[int, Depends(endpoint_dependency)], ) if routing_style == "router": @@ -282,7 +285,7 @@ def test_dependency_cache_in_same_endpoint( ], dependency_factory=dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -293,25 +296,23 @@ def test_dependency_cache_in_same_endpoint( ], dependency_factory=dependency_factory, expected_activation_times=3, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_dependency_cache_in_different_endpoints( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=use_cache + use_cache=use_cache, ) app = FastAPI() @@ -331,7 +332,7 @@ def test_dependency_cache_in_different_endpoints( is_websocket=is_websocket, annotation1=Annotated[int, depends], annotation2=Annotated[int, depends], - annotation3=Annotated[int, Depends(endpoint_dependency)] + annotation3=Annotated[int, Depends(endpoint_dependency)], ) create_endpoint_3_annotations( @@ -340,7 +341,7 @@ def test_dependency_cache_in_different_endpoints( is_websocket=is_websocket, annotation1=Annotated[int, depends], annotation2=Annotated[int, depends], - annotation3=Annotated[int, Depends(endpoint_dependency)] + annotation3=Annotated[int, Depends(endpoint_dependency)], ) if routing_style == "router": @@ -357,7 +358,7 @@ def test_dependency_cache_in_different_endpoints( ], dependency_factory=dependency_factory, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) else: expect_correct_amount_of_dependency_activations( @@ -370,23 +371,24 @@ def test_dependency_cache_in_different_endpoints( ], dependency_factory=dependency_factory, expected_activation_times=5, - is_websocket=is_websocket + is_websocket=is_websocket, ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app", "router"]) def test_no_cached_dependency( - dependency_style: DependencyStyle, - routing_style, - is_websocket, + dependency_style: DependencyStyle, + routing_style, + is_websocket, ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", - use_cache=False + use_cache=False, ) app = FastAPI() @@ -402,7 +404,7 @@ def test_no_cached_dependency( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router": @@ -413,25 +415,27 @@ def test_no_cached_dependency( dependency_factory=dependency_factory, urls_and_responses=[("/test", 1)] * 2, expected_activation_times=1, - is_websocket=is_websocket + is_websocket=is_websocket, ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize("annotation", [ - Annotated[str, Path()], - Annotated[str, Body()], - Annotated[str, Query()], - Annotated[str, Header()], - SecurityScopes, - Annotated[str, Cookie()], - Annotated[str, Form()], - Annotated[str, File()], - BackgroundTasks, -]) +@pytest.mark.parametrize( + "annotation", + [ + Annotated[str, Path()], + Annotated[str, Body()], + Annotated[str, Query()], + Annotated[str, Header()], + SecurityScopes, + Annotated[str, Cookie()], + Annotated[str, Form()], + Annotated[str, File()], + BackgroundTasks, + ], +) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( - annotation, - is_websocket + annotation, is_websocket ): async def dependency_func(param: annotation) -> None: yield @@ -444,8 +448,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( path="/test", is_websocket=is_websocket, annotation=Annotated[ - None, - Depends(dependency_func, dependency_scope="lifespan") + None, Depends(dependency_func, dependency_scope="lifespan") ], ) @@ -453,16 +456,15 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( - dependency_style: DependencyStyle, - is_websocket + dependency_style: DependencyStyle, is_websocket ): dependency_factory = DependencyFactory(dependency_style) async def lifespan_scoped_dependency( - param: Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )] + param: Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) -> AsyncGenerator[int, None]: yield param @@ -473,7 +475,7 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( path="/test", is_websocket=is_websocket, annotation=Annotated[int, Depends(lifespan_scoped_dependency)], - expected_value=1 + expected_value=1, ) expect_correct_amount_of_dependency_activations( @@ -481,24 +483,22 @@ def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( dependency_factory=dependency_factory, expected_activation_times=1, urls_and_responses=[("/test", 1)] * 2, - is_websocket=is_websocket + is_websocket=is_websocket, ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) -@pytest.mark.parametrize([ - "dependency_style", - "supports_teardown" -], [ - (DependencyStyle.SYNC_FUNCTION, False), - (DependencyStyle.ASYNC_FUNCTION, False), - (DependencyStyle.SYNC_GENERATOR, True), - (DependencyStyle.ASYNC_GENERATOR, True), -]) +@pytest.mark.parametrize( + ["dependency_style", "supports_teardown"], + [ + (DependencyStyle.SYNC_FUNCTION, False), + (DependencyStyle.ASYNC_FUNCTION, False), + (DependencyStyle.SYNC_GENERATOR, True), + (DependencyStyle.ASYNC_GENERATOR, True), + ], +) def test_the_same_dependency_can_work_in_different_scopes( - dependency_style: DependencyStyle, - supports_teardown, - is_websocket + dependency_style: DependencyStyle, supports_teardown, is_websocket ): dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -507,14 +507,14 @@ def test_the_same_dependency_can_work_in_different_scopes( router=app, path="/test", is_websocket=is_websocket, - annotation1=Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="endpoint" - )], - annotation2=Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )], + annotation1=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"), + ], + annotation2=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], ) if is_websocket: get_response = use_websocket @@ -548,17 +548,20 @@ def test_the_same_dependency_can_work_in_different_scopes( assert dependency_factory.deactivation_times == 0 -@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) +@pytest.mark.parametrize( + "lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"] +) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( - dependency_style: DependencyStyle, - is_websocket, - lifespan_style: Literal["lifespan_function", "lifespan_events"] + dependency_style: DependencyStyle, + is_websocket, + lifespan_style: Literal["lifespan_function", "lifespan_events"], ): lifespan_started = False lifespan_ended = False if lifespan_style == "lifespan_generator": + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: nonlocal lifespan_started @@ -571,6 +574,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( elif lifespan_style == "events_decorator": app = FastAPI() with warnings.catch_warnings(action="ignore", category=DeprecationWarning): + @app.on_event("startup") async def startup() -> None: nonlocal lifespan_started @@ -581,6 +585,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( nonlocal lifespan_ended lifespan_ended = True elif lifespan_style == "events_constructor": + async def startup() -> None: nonlocal lifespan_started lifespan_started = True @@ -588,6 +593,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( async def shutdown() -> None: nonlocal lifespan_ended lifespan_ended = True + app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) else: assert_never(lifespan_style) @@ -598,11 +604,11 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[int, Depends( - dependency_factory.get_dependency(), - dependency_scope="lifespan" - )], - expected_value=1 + annotation=Annotated[ + int, + Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), + ], + expected_value=1, ) expect_correct_amount_of_dependency_activations( @@ -610,20 +616,22 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( dependency_factory=dependency_factory, expected_activation_times=1, urls_and_responses=[("/test", 1)] * 2, - is_websocket=is_websocket + is_websocket=is_websocket, ) assert lifespan_started and lifespan_ended + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("depends_class", [Depends, Security]) def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( - depends_class, - is_websocket + depends_class, is_websocket ): async def sub_dependency() -> None: pass - async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: + async def dependency_func( + param: Annotated[None, depends_class(sub_dependency)], + ) -> None: yield app = FastAPI() @@ -633,20 +641,20 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( router=app, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], + annotation=Annotated[ + None, Depends(dependency_func, dependency_scope="lifespan") + ], ) + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_dependencies_must_provide_correct_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -656,19 +664,21 @@ def test_dependencies_must_provide_correct_dependency_scope( router = APIRouter() with pytest.raises( - InvalidDependencyScope, - match=r'Dependency "value" of .* has an invalid scope: ' - r'"incorrect"' + InvalidDependencyScope, + match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"', ): create_endpoint_1_annotation( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[None, Depends( - dependency_factory.get_dependency(), - dependency_scope="incorrect", - use_cache=use_cache, - )] + annotation=Annotated[ + None, + Depends( + dependency_factory.get_dependency(), + dependency_scope="incorrect", + use_cache=use_cache, + ), + ], ) @@ -677,12 +687,9 @@ def test_dependencies_must_provide_correct_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_incorrect_dependency_scope( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -705,7 +712,7 @@ def test_endpoints_report_incorrect_dependency_scope( router=router, path="/test", is_websocket=is_websocket, - annotation=Annotated[int, depends] + annotation=Annotated[int, depends], ) @@ -714,12 +721,9 @@ def test_endpoints_report_incorrect_dependency_scope( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_dependency( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): - dependency_factory= DependencyFactory(dependency_style) + dependency_factory = DependencyFactory(dependency_style) app = FastAPI() @@ -739,7 +743,7 @@ def test_endpoints_report_uninitialized_dependency( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router_endpoint": @@ -757,7 +761,9 @@ def test_endpoints_report_uninitialized_dependency( else: client.post("/test") finally: - client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies + client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( + dependencies + ) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) @@ -765,10 +771,7 @@ def test_endpoints_report_uninitialized_dependency( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_endpoints_report_uninitialized_internal_lifespan( - dependency_style: DependencyStyle, - routing_style, - use_cache, - is_websocket + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket ): dependency_factory = DependencyFactory(dependency_style) @@ -790,7 +793,7 @@ def test_endpoints_report_uninitialized_internal_lifespan( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router_endpoint": @@ -816,12 +819,9 @@ def test_endpoints_report_uninitialized_internal_lifespan( @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) def test_bad_lifespan_scoped_dependencies( - use_cache, - dependency_style: DependencyStyle, - routing_style, - is_websocket + use_cache, dependency_style: DependencyStyle, routing_style, is_websocket ): - dependency_factory= DependencyFactory(dependency_style, should_error=True) + dependency_factory = DependencyFactory(dependency_style, should_error=True) depends = Depends( dependency_factory.get_dependency(), dependency_scope="lifespan", @@ -841,7 +841,7 @@ def test_bad_lifespan_scoped_dependencies( path="/test", is_websocket=is_websocket, annotation=Annotated[int, depends], - expected_value=1 + expected_value=1, ) if routing_style == "router_endpoint": diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index e733205f5..1f0f100a8 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, FastAPI, WebSocket from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect -T = TypeVar('T') +T = TypeVar("T") class DependencyStyle(StrEnum): @@ -21,10 +21,11 @@ class IntentionallyBadDependency(Exception): class DependencyFactory: def __init__( - self, - dependency_style: DependencyStyle, *, - should_error: bool = False, - value_offset: int = 0, + self, + dependency_style: DependencyStyle, + *, + should_error: bool = False, + value_offset: int = 0, ): self.activation_times = 0 self.deactivation_times = 0 @@ -90,12 +91,13 @@ def use_websocket(client: TestClient, url: str) -> Any: def create_endpoint_0_annotations( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, ) -> None: if is_websocket: + @router.websocket(path) async def endpoint(websocket: WebSocket) -> None: await websocket.accept() @@ -104,25 +106,24 @@ def create_endpoint_0_annotations( except WebSocketDisconnect: pass else: + @router.post(path) async def endpoint() -> None: return None def create_endpoint_1_annotation( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, - annotation: Any, - expected_value: Any = None + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation: Any, + expected_value: Any = None, ) -> None: if is_websocket: + @router.websocket(path) - async def endpoint( - websocket: WebSocket, - value: annotation - ) -> None: + async def endpoint(websocket: WebSocket, value: annotation) -> None: if expected_value is not None: assert value == expected_value @@ -132,29 +133,30 @@ def create_endpoint_1_annotation( except WebSocketDisconnect: pass else: + @router.post(path) - async def endpoint( - value: annotation - ) -> None: + async def endpoint(value: annotation) -> None: if expected_value is not None: assert value == expected_value return value + def create_endpoint_2_annotations( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, - annotation1: Any, - annotation2: Any, + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, ) -> None: if is_websocket: + @router.websocket(path) async def endpoint( - websocket: WebSocket, - value1: annotation1, - value2: annotation2, + websocket: WebSocket, + value1: annotation1, + value2: annotation2, ) -> None: await websocket.accept() try: @@ -162,30 +164,32 @@ def create_endpoint_2_annotations( except WebSocketDisconnect: await websocket.close() else: + @router.post(path) async def endpoint( - value1: annotation1, - value2: annotation2, + value1: annotation1, + value2: annotation2, ) -> list[Any]: return [value1, value2] def create_endpoint_3_annotations( - *, - router: Union[APIRouter, FastAPI], - path: str, - is_websocket: bool, - annotation1: Any, - annotation2: Any, - annotation3: Any + *, + router: Union[APIRouter, FastAPI], + path: str, + is_websocket: bool, + annotation1: Any, + annotation2: Any, + annotation3: Any, ) -> None: if is_websocket: + @router.websocket(path) async def endpoint( - websocket: WebSocket, - value1: annotation1, - value2: annotation2, - value3: annotation3 + websocket: WebSocket, + value1: annotation1, + value2: annotation2, + value3: annotation3, ) -> None: await websocket.accept() try: @@ -193,10 +197,9 @@ def create_endpoint_3_annotations( except WebSocketDisconnect: await websocket.close() else: + @router.post(path) async def endpoint( - value1: annotation1, - value2: annotation2, - value3: annotation3 + value1: annotation1, value2: annotation2, value3: annotation3 ) -> list[Any]: return [value1, value2, value3]