Browse Source

Added coverage

pull/12529/head
Nir Schulman 9 months ago
parent
commit
01e9da1ca2
  1. 5
      fastapi/dependencies/utils.py
  2. 2
      fastapi/lifespan.py
  3. 16
      fastapi/routing.py
  4. 10
      tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py
  5. 71
      tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py
  6. 21
      tests/test_lifespan_scoped_dependencies/testing_utilities.py

5
fastapi/dependencies/utils.py

@ -388,10 +388,9 @@ def get_endpoint_dependant(
) )
if isinstance(sub_dependant, EndpointDependant): if isinstance(sub_dependant, EndpointDependant):
dependant.endpoint_dependencies.append(sub_dependant) dependant.endpoint_dependencies.append(sub_dependant)
elif isinstance(sub_dependant, LifespanDependant):
dependant.lifespan_dependencies.append(sub_dependant)
else: else:
assert_never(sub_dependant) assert isinstance(sub_dependant, LifespanDependant)
dependant.lifespan_dependencies.append(sub_dependant)
continue continue
if add_non_field_param_to_dependency( if add_non_field_param_to_dependency(
param_name=param_name, param_name=param_name,

2
fastapi/lifespan.py

@ -7,7 +7,7 @@ from fastapi.dependencies.models import LifespanDependant, LifespanDependantCach
from fastapi.dependencies.utils import solve_lifespan_dependant from fastapi.dependencies.utils import solve_lifespan_dependant
from fastapi.routing import APIRoute, APIWebSocketRoute from fastapi.routing import APIRoute, APIWebSocketRoute
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: nocover
from fastapi import FastAPI from fastapi import FastAPI

16
fastapi/routing.py

@ -73,7 +73,7 @@ from starlette.routing import (
from starlette.routing import Mount as Mount # noqa from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.types import AppType, ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, assert_never, deprecated from typing_extensions import Annotated, Doc, deprecated
def _prepare_response_content( def _prepare_response_content(
@ -407,14 +407,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
sub_dependant = get_parameterless_sub_dependant( sub_dependant = get_parameterless_sub_dependant(
depends=depends, path=self.path_format, caller=self.__call__, index=i depends=depends, path=self.path_format, caller=self.__call__, index=i
) )
if depends.dependency_scope == "endpoint": if isinstance(sub_dependant, EndpointDependant):
assert isinstance(sub_dependant, EndpointDependant) assert isinstance(sub_dependant, EndpointDependant)
self.dependant.endpoint_dependencies.insert(0, sub_dependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan": else:
assert isinstance(sub_dependant, LifespanDependant) assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
@ -572,14 +570,12 @@ class APIRoute(routing.Route):
sub_dependant = get_parameterless_sub_dependant( sub_dependant = get_parameterless_sub_dependant(
depends=depends, path=self.path_format, caller=self.__call__, index=i depends=depends, path=self.path_format, caller=self.__call__, index=i
) )
if depends.dependency_scope == "endpoint": if isinstance(sub_dependant, EndpointDependant):
assert isinstance(sub_dependant, EndpointDependant)
self.dependant.endpoint_dependencies.insert(0, sub_dependant) self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan": else:
assert isinstance(sub_dependant, LifespanDependant) assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant) self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant) self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields( self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params self._flat_dependant.body_params

10
tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py

@ -482,10 +482,10 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete
annotation, is_websocket annotation, is_websocket
): ):
async def dependency_func() -> None: async def dependency_func() -> None:
yield yield # pragma: nocover
async def override_dependency_func(param: annotation) -> None: async def override_dependency_func(param: annotation) -> None:
yield yield # pragma: nocover
app = FastAPI() app = FastAPI()
app.dependency_overrides[dependency_func] = override_dependency_func app.dependency_overrides[dependency_func] = override_dependency_func
@ -551,15 +551,15 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
depends_class, is_websocket depends_class, is_websocket
): ):
async def sub_dependency() -> None: async def sub_dependency() -> None:
pass pass # pragma: nocover
async def dependency_func() -> None: async def dependency_func() -> None:
yield yield # pragma: nocover
async def override_dependency_func( async def override_dependency_func(
param: Annotated[None, depends_class(sub_dependency)], param: Annotated[None, depends_class(sub_dependency)],
) -> None: ) -> None:
yield yield # pragma: nocover
app = FastAPI() app = FastAPI()

71
tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py

@ -4,6 +4,8 @@ from time import sleep
from typing import Any, AsyncGenerator, Dict, List, Tuple from typing import Any, AsyncGenerator, Dict, List, Tuple
import pytest import pytest
from setuptools import depends
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
BackgroundTasks, BackgroundTasks,
@ -19,6 +21,7 @@ from fastapi import (
Request, Request,
WebSocket, WebSocket,
) )
from fastapi.dependencies.utils import get_endpoint_dependant
from fastapi.exceptions import ( from fastapi.exceptions import (
DependencyScopeConflict, DependencyScopeConflict,
InvalidDependencyScope, InvalidDependencyScope,
@ -443,7 +446,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation, is_websocket annotation, is_websocket
): ):
async def dependency_func(param: annotation) -> None: async def dependency_func(param: annotation) -> None:
yield yield # pragma: nocover
app = FastAPI() app = FastAPI()
@ -598,7 +601,8 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
async def shutdown() -> None: async def shutdown() -> None:
nonlocal lifespan_ended nonlocal lifespan_ended
lifespan_ended = True lifespan_ended = True
elif lifespan_style == "events_constructor": else:
assert lifespan_style == "events_constructor"
async def startup() -> None: async def startup() -> None:
nonlocal lifespan_started nonlocal lifespan_started
@ -609,8 +613,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
lifespan_ended = True lifespan_ended = True
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) app = FastAPI(on_startup=[startup], on_shutdown=[shutdown])
else:
assert_never(lifespan_style)
dependency_factory = DependencyFactory(dependency_style) dependency_factory = DependencyFactory(dependency_style)
@ -641,12 +644,12 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class, is_websocket depends_class, is_websocket
): ):
async def sub_dependency() -> None: async def sub_dependency() -> None:
pass pass # pragma: nocover
async def dependency_func( async def dependency_func(
param: Annotated[None, depends_class(sub_dependency)], param: Annotated[None, depends_class(sub_dependency)],
) -> None: ) -> None:
yield pass # pragma: nocover
app = FastAPI() app = FastAPI()
@ -730,6 +733,39 @@ def test_endpoints_report_incorrect_dependency_scope(
) )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_endpoints_report_incorrect_dependency_scope_at_router_scope(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
if routing_style == "app_endpoint":
app = FastAPI(dependencies=[depends])
router = app
else:
router = APIRouter(dependencies=[depends])
with pytest.raises(InvalidDependencyScope):
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) @pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False]) @pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) @pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@ -866,3 +902,26 @@ def test_bad_lifespan_scoped_dependencies(
pass pass
assert exception_info.value.args == (1,) assert exception_info.value.args == (1,)
def test_endpoint_dependant_backwards_compatibility():
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR)
def endpoint(
dependency1: Annotated[int, Depends(dependency_factory.get_dependency())],
dependency2: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)],
):
pass # pragma: nocover
dependant = get_endpoint_dependant(
path="/test",
call=endpoint,
name="endpoint",
)
assert dependant.dependencies == tuple(
dependant.lifespan_dependencies +
dependant.endpoint_dependencies
)

21
tests/test_lifespan_scoped_dependencies/testing_utilities.py

@ -1,10 +1,8 @@
import threading
from enum import Enum from enum import Enum
from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union
from fastapi import APIRouter, FastAPI, WebSocket from fastapi import APIRouter, FastAPI, WebSocket
from starlette.testclient import TestClient from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from typing_extensions import assert_never from typing_extensions import assert_never
T = TypeVar("T") T = TypeVar("T")
@ -34,7 +32,6 @@ class DependencyFactory:
self.dependency_style = dependency_style self.dependency_style = dependency_style
self._should_error = should_error self._should_error = should_error
self._value_offset = value_offset self._value_offset = value_offset
self._event = threading.Event()
def get_dependency(self): def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
@ -49,7 +46,7 @@ class DependencyFactory:
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency return self._asynchronous_generator_dependency
assert_never(self.dependency_style) assert_never(self.dependency_style) # pragma: nocover
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1 self.activation_times += 1
@ -58,7 +55,6 @@ class DependencyFactory:
yield self.activation_times + self._value_offset yield self.activation_times + self._value_offset
self.deactivation_times += 1 self.deactivation_times += 1
self._event.set()
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1 self.activation_times += 1
@ -67,7 +63,6 @@ class DependencyFactory:
yield self.activation_times + self._value_offset yield self.activation_times + self._value_offset
self.deactivation_times += 1 self.deactivation_times += 1
self._event.set()
async def _asynchronous_function_dependency(self) -> T: async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1 self.activation_times += 1
@ -106,10 +101,7 @@ def create_endpoint_0_annotations(
@router.websocket(path) @router.websocket(path)
async def endpoint(websocket: WebSocket) -> None: async def endpoint(websocket: WebSocket) -> None:
await websocket.accept() await websocket.accept()
try:
await websocket.send_json(None) await websocket.send_json(None)
except WebSocketDisconnect:
pass
else: else:
@router.post(path) @router.post(path)
@ -133,10 +125,7 @@ def create_endpoint_1_annotation(
assert value == expected_value assert value == expected_value
await websocket.accept() await websocket.accept()
try:
await websocket.send_json(value) await websocket.send_json(value)
except WebSocketDisconnect:
pass
else: else:
@router.post(path) @router.post(path)
@ -164,10 +153,7 @@ def create_endpoint_2_annotations(
value2: annotation2, value2: annotation2,
) -> None: ) -> None:
await websocket.accept() await websocket.accept()
try:
await websocket.send_json([value1, value2]) await websocket.send_json([value1, value2])
except WebSocketDisconnect:
await websocket.close()
else: else:
@router.post(path) @router.post(path)
@ -197,10 +183,7 @@ def create_endpoint_3_annotations(
value3: annotation3, value3: annotation3,
) -> None: ) -> None:
await websocket.accept() await websocket.accept()
try:
await websocket.send_json([value1, value2, value3]) await websocket.send_json([value1, value2, value3])
except WebSocketDisconnect:
await websocket.close()
else: else:
@router.post(path) @router.post(path)

Loading…
Cancel
Save