From 01e9da1ca212ee84cdde0f0ec3106e268794a2a7 Mon Sep 17 00:00:00 2001 From: Nir Schulman Date: Sat, 9 Nov 2024 12:32:58 +0200 Subject: [PATCH] Added coverage --- fastapi/dependencies/utils.py | 5 +- fastapi/lifespan.py | 2 +- fastapi/routing.py | 16 ++--- .../test_dependency_overrides.py | 10 +-- .../test_endpoint_usage.py | 71 +++++++++++++++++-- .../testing_utilities.py | 29 ++------ 6 files changed, 85 insertions(+), 48 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96a00c12a..94e983e1e 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -388,10 +388,9 @@ def get_endpoint_dependant( ) if isinstance(sub_dependant, EndpointDependant): dependant.endpoint_dependencies.append(sub_dependant) - elif isinstance(sub_dependant, LifespanDependant): - dependant.lifespan_dependencies.append(sub_dependant) else: - assert_never(sub_dependant) + assert isinstance(sub_dependant, LifespanDependant) + dependant.lifespan_dependencies.append(sub_dependant) continue if add_non_field_param_to_dependency( param_name=param_name, diff --git a/fastapi/lifespan.py b/fastapi/lifespan.py index 9b8afba03..7d53fc00a 100644 --- a/fastapi/lifespan.py +++ b/fastapi/lifespan.py @@ -7,7 +7,7 @@ from fastapi.dependencies.models import LifespanDependant, LifespanDependantCach from fastapi.dependencies.utils import solve_lifespan_dependant from fastapi.routing import APIRoute, APIWebSocketRoute -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: nocover from fastapi import FastAPI diff --git a/fastapi/routing.py b/fastapi/routing.py index 376ad9c8b..7a7344e69 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -73,7 +73,7 @@ from starlette.routing import ( from starlette.routing import Mount as Mount # noqa from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket -from typing_extensions import Annotated, Doc, assert_never, deprecated +from typing_extensions import Annotated, Doc, deprecated def _prepare_response_content( @@ -407,14 +407,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, caller=self.__call__, index=i ) - if depends.dependency_scope == "endpoint": + if isinstance(sub_dependant, EndpointDependant): assert isinstance(sub_dependant, EndpointDependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant) - elif depends.dependency_scope == "lifespan": + else: assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) - else: - assert_never(depends.dependency_scope) self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( @@ -572,14 +570,12 @@ class APIRoute(routing.Route): sub_dependant = get_parameterless_sub_dependant( depends=depends, path=self.path_format, caller=self.__call__, index=i ) - if depends.dependency_scope == "endpoint": - assert isinstance(sub_dependant, EndpointDependant) + if isinstance(sub_dependant, EndpointDependant): self.dependant.endpoint_dependencies.insert(0, sub_dependant) - elif depends.dependency_scope == "lifespan": + else: assert isinstance(sub_dependant, LifespanDependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant) - else: - assert_never(depends.dependency_scope) + self._flat_dependant = get_flat_dependant(self.dependant) self._embed_body_fields = _should_embed_body_fields( self._flat_dependant.body_params diff --git a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py index 9c7800ee2..eb61ed248 100644 --- a/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py +++ b/tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py @@ -482,10 +482,10 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete annotation, is_websocket ): async def dependency_func() -> None: - yield + yield # pragma: nocover async def override_dependency_func(param: annotation) -> None: - yield + yield # pragma: nocover app = FastAPI() app.dependency_overrides[dependency_func] = override_dependency_func @@ -551,15 +551,15 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen depends_class, is_websocket ): async def sub_dependency() -> None: - pass + pass # pragma: nocover async def dependency_func() -> None: - yield + yield # pragma: nocover async def override_dependency_func( param: Annotated[None, depends_class(sub_dependency)], ) -> None: - yield + yield # pragma: nocover app = FastAPI() diff --git a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py index a6fdea9e4..65c3c2dc1 100644 --- a/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py +++ b/tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py @@ -4,6 +4,8 @@ from time import sleep from typing import Any, AsyncGenerator, Dict, List, Tuple import pytest +from setuptools import depends + from fastapi import ( APIRouter, BackgroundTasks, @@ -19,6 +21,7 @@ from fastapi import ( Request, WebSocket, ) +from fastapi.dependencies.utils import get_endpoint_dependant from fastapi.exceptions import ( DependencyScopeConflict, InvalidDependencyScope, @@ -443,7 +446,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( annotation, is_websocket ): async def dependency_func(param: annotation) -> None: - yield + yield # pragma: nocover app = FastAPI() @@ -598,7 +601,8 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( async def shutdown() -> None: nonlocal lifespan_ended lifespan_ended = True - elif lifespan_style == "events_constructor": + else: + assert lifespan_style == "events_constructor" async def startup() -> None: nonlocal lifespan_started @@ -609,8 +613,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( lifespan_ended = True app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) - else: - assert_never(lifespan_style) + dependency_factory = DependencyFactory(dependency_style) @@ -641,12 +644,12 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( depends_class, is_websocket ): async def sub_dependency() -> None: - pass + pass # pragma: nocover async def dependency_func( param: Annotated[None, depends_class(sub_dependency)], ) -> None: - yield + pass # pragma: nocover app = FastAPI() @@ -730,6 +733,39 @@ def test_endpoints_report_incorrect_dependency_scope( ) +@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) +@pytest.mark.parametrize("use_cache", [True, False]) +@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) +@pytest.mark.parametrize("routing_style", ["app", "router"]) +def test_endpoints_report_incorrect_dependency_scope_at_router_scope( + dependency_style: DependencyStyle, routing_style, use_cache, is_websocket +): + dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) + + depends = Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + ) + + # We intentionally change the dependency scope here to bypass the + # validation at the function level. + depends.dependency_scope = "asdad" + + if routing_style == "app_endpoint": + app = FastAPI(dependencies=[depends]) + router = app + else: + router = APIRouter(dependencies=[depends]) + + + with pytest.raises(InvalidDependencyScope): + create_endpoint_0_annotations( + router=router, + path="/test", + is_websocket=is_websocket, + ) + + @pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @@ -866,3 +902,26 @@ def test_bad_lifespan_scoped_dependencies( pass assert exception_info.value.args == (1,) + +def test_endpoint_dependant_backwards_compatibility(): + dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) + + def endpoint( + dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], + dependency2: Annotated[int, Depends( + dependency_factory.get_dependency(), + dependency_scope="lifespan" + )], + ): + pass # pragma: nocover + + dependant = get_endpoint_dependant( + path="/test", + call=endpoint, + name="endpoint", + ) + + assert dependant.dependencies == tuple( + dependant.lifespan_dependencies + + dependant.endpoint_dependencies + ) \ No newline at end of file diff --git a/tests/test_lifespan_scoped_dependencies/testing_utilities.py b/tests/test_lifespan_scoped_dependencies/testing_utilities.py index 77054f941..88e0925bc 100644 --- a/tests/test_lifespan_scoped_dependencies/testing_utilities.py +++ b/tests/test_lifespan_scoped_dependencies/testing_utilities.py @@ -1,10 +1,8 @@ -import threading from enum import Enum from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union from fastapi import APIRouter, FastAPI, WebSocket -from starlette.testclient import TestClient -from starlette.websockets import WebSocketDisconnect +from fastapi.testclient import TestClient from typing_extensions import assert_never T = TypeVar("T") @@ -34,7 +32,6 @@ class DependencyFactory: self.dependency_style = dependency_style self._should_error = should_error self._value_offset = value_offset - self._event = threading.Event() def get_dependency(self): if self.dependency_style == DependencyStyle.SYNC_FUNCTION: @@ -49,7 +46,7 @@ class DependencyFactory: if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: return self._asynchronous_generator_dependency - assert_never(self.dependency_style) + assert_never(self.dependency_style) # pragma: nocover async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: self.activation_times += 1 @@ -58,7 +55,6 @@ class DependencyFactory: yield self.activation_times + self._value_offset self.deactivation_times += 1 - self._event.set() def _synchronous_generator_dependency(self) -> Generator[T, None, None]: self.activation_times += 1 @@ -67,7 +63,6 @@ class DependencyFactory: yield self.activation_times + self._value_offset self.deactivation_times += 1 - self._event.set() async def _asynchronous_function_dependency(self) -> T: self.activation_times += 1 @@ -106,10 +101,7 @@ def create_endpoint_0_annotations( @router.websocket(path) async def endpoint(websocket: WebSocket) -> None: await websocket.accept() - try: - await websocket.send_json(None) - except WebSocketDisconnect: - pass + await websocket.send_json(None) else: @router.post(path) @@ -133,10 +125,7 @@ def create_endpoint_1_annotation( assert value == expected_value await websocket.accept() - try: - await websocket.send_json(value) - except WebSocketDisconnect: - pass + await websocket.send_json(value) else: @router.post(path) @@ -164,10 +153,7 @@ def create_endpoint_2_annotations( value2: annotation2, ) -> None: await websocket.accept() - try: - await websocket.send_json([value1, value2]) - except WebSocketDisconnect: - await websocket.close() + await websocket.send_json([value1, value2]) else: @router.post(path) @@ -197,10 +183,7 @@ def create_endpoint_3_annotations( value3: annotation3, ) -> None: await websocket.accept() - try: - await websocket.send_json([value1, value2, value3]) - except WebSocketDisconnect: - await websocket.close() + await websocket.send_json([value1, value2, value3]) else: @router.post(path)