Browse Source

Added coverage

pull/12529/head
Nir Schulman 9 months ago
parent
commit
01e9da1ca2
  1. 5
      fastapi/dependencies/utils.py
  2. 2
      fastapi/lifespan.py
  3. 16
      fastapi/routing.py
  4. 10
      tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py
  5. 71
      tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py
  6. 29
      tests/test_lifespan_scoped_dependencies/testing_utilities.py

5
fastapi/dependencies/utils.py

@ -388,10 +388,9 @@ def get_endpoint_dependant(
)
if isinstance(sub_dependant, EndpointDependant):
dependant.endpoint_dependencies.append(sub_dependant)
elif isinstance(sub_dependant, LifespanDependant):
dependant.lifespan_dependencies.append(sub_dependant)
else:
assert_never(sub_dependant)
assert isinstance(sub_dependant, LifespanDependant)
dependant.lifespan_dependencies.append(sub_dependant)
continue
if add_non_field_param_to_dependency(
param_name=param_name,

2
fastapi/lifespan.py

@ -7,7 +7,7 @@ from fastapi.dependencies.models import LifespanDependant, LifespanDependantCach
from fastapi.dependencies.utils import solve_lifespan_dependant
from fastapi.routing import APIRoute, APIWebSocketRoute
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: nocover
from fastapi import FastAPI

16
fastapi/routing.py

@ -73,7 +73,7 @@ from starlette.routing import (
from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, assert_never, deprecated
from typing_extensions import Annotated, Doc, deprecated
def _prepare_response_content(
@ -407,14 +407,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
sub_dependant = get_parameterless_sub_dependant(
depends=depends, path=self.path_format, caller=self.__call__, index=i
)
if depends.dependency_scope == "endpoint":
if isinstance(sub_dependant, EndpointDependant):
assert isinstance(sub_dependant, EndpointDependant)
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
else:
assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
@ -572,14 +570,12 @@ class APIRoute(routing.Route):
sub_dependant = get_parameterless_sub_dependant(
depends=depends, path=self.path_format, caller=self.__call__, index=i
)
if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant)
if isinstance(sub_dependant, EndpointDependant):
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
else:
assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params

10
tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py

@ -482,10 +482,10 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_paramete
annotation, is_websocket
):
async def dependency_func() -> None:
yield
yield # pragma: nocover
async def override_dependency_func(param: annotation) -> None:
yield
yield # pragma: nocover
app = FastAPI()
app.dependency_overrides[dependency_func] = override_dependency_func
@ -551,15 +551,15 @@ def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependen
depends_class, is_websocket
):
async def sub_dependency() -> None:
pass
pass # pragma: nocover
async def dependency_func() -> None:
yield
yield # pragma: nocover
async def override_dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield
yield # pragma: nocover
app = FastAPI()

71
tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py

@ -4,6 +4,8 @@ from time import sleep
from typing import Any, AsyncGenerator, Dict, List, Tuple
import pytest
from setuptools import depends
from fastapi import (
APIRouter,
BackgroundTasks,
@ -19,6 +21,7 @@ from fastapi import (
Request,
WebSocket,
)
from fastapi.dependencies.utils import get_endpoint_dependant
from fastapi.exceptions import (
DependencyScopeConflict,
InvalidDependencyScope,
@ -443,7 +446,7 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation, is_websocket
):
async def dependency_func(param: annotation) -> None:
yield
yield # pragma: nocover
app = FastAPI()
@ -598,7 +601,8 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
async def shutdown() -> None:
nonlocal lifespan_ended
lifespan_ended = True
elif lifespan_style == "events_constructor":
else:
assert lifespan_style == "events_constructor"
async def startup() -> None:
nonlocal lifespan_started
@ -609,8 +613,7 @@ def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
lifespan_ended = True
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown])
else:
assert_never(lifespan_style)
dependency_factory = DependencyFactory(dependency_style)
@ -641,12 +644,12 @@ def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class, is_websocket
):
async def sub_dependency() -> None:
pass
pass # pragma: nocover
async def dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield
pass # pragma: nocover
app = FastAPI()
@ -730,6 +733,39 @@ def test_endpoints_report_incorrect_dependency_scope(
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_endpoints_report_incorrect_dependency_scope_at_router_scope(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
if routing_style == "app_endpoint":
app = FastAPI(dependencies=[depends])
router = app
else:
router = APIRouter(dependencies=[depends])
with pytest.raises(InvalidDependencyScope):
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@ -866,3 +902,26 @@ def test_bad_lifespan_scoped_dependencies(
pass
assert exception_info.value.args == (1,)
def test_endpoint_dependant_backwards_compatibility():
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR)
def endpoint(
dependency1: Annotated[int, Depends(dependency_factory.get_dependency())],
dependency2: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)],
):
pass # pragma: nocover
dependant = get_endpoint_dependant(
path="/test",
call=endpoint,
name="endpoint",
)
assert dependant.dependencies == tuple(
dependant.lifespan_dependencies +
dependant.endpoint_dependencies
)

29
tests/test_lifespan_scoped_dependencies/testing_utilities.py

@ -1,10 +1,8 @@
import threading
from enum import Enum
from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union
from fastapi import APIRouter, FastAPI, WebSocket
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from fastapi.testclient import TestClient
from typing_extensions import assert_never
T = TypeVar("T")
@ -34,7 +32,6 @@ class DependencyFactory:
self.dependency_style = dependency_style
self._should_error = should_error
self._value_offset = value_offset
self._event = threading.Event()
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
@ -49,7 +46,7 @@ class DependencyFactory:
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style)
assert_never(self.dependency_style) # pragma: nocover
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
@ -58,7 +55,6 @@ class DependencyFactory:
yield self.activation_times + self._value_offset
self.deactivation_times += 1
self._event.set()
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
@ -67,7 +63,6 @@ class DependencyFactory:
yield self.activation_times + self._value_offset
self.deactivation_times += 1
self._event.set()
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
@ -106,10 +101,7 @@ def create_endpoint_0_annotations(
@router.websocket(path)
async def endpoint(websocket: WebSocket) -> None:
await websocket.accept()
try:
await websocket.send_json(None)
except WebSocketDisconnect:
pass
await websocket.send_json(None)
else:
@router.post(path)
@ -133,10 +125,7 @@ def create_endpoint_1_annotation(
assert value == expected_value
await websocket.accept()
try:
await websocket.send_json(value)
except WebSocketDisconnect:
pass
await websocket.send_json(value)
else:
@router.post(path)
@ -164,10 +153,7 @@ def create_endpoint_2_annotations(
value2: annotation2,
) -> None:
await websocket.accept()
try:
await websocket.send_json([value1, value2])
except WebSocketDisconnect:
await websocket.close()
await websocket.send_json([value1, value2])
else:
@router.post(path)
@ -197,10 +183,7 @@ def create_endpoint_3_annotations(
value3: annotation3,
) -> None:
await websocket.accept()
try:
await websocket.send_json([value1, value2, value3])
except WebSocketDisconnect:
await websocket.close()
await websocket.send_json([value1, value2, value3])
else:
@router.post(path)

Loading…
Cancel
Save