You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

624 lines
19 KiB

from typing import Any, AsyncGenerator, List, Tuple
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
Request,
WebSocket,
)
from fastapi.exceptions import DependencyScopeConflict
from fastapi.params import Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated, Literal
from tests.test_lifespan_scoped_dependencies.testing_utilities import (
DependencyFactory,
DependencyStyle,
IntentionallyBadDependency,
create_endpoint_0_annotations,
create_endpoint_1_annotation,
create_endpoint_3_annotations,
use_endpoint,
use_websocket,
)
def expect_correct_amount_of_dependency_activations(
    *,
    app: FastAPI,
    dependency_factory: DependencyFactory,
    override_dependency_factory: DependencyFactory,
    urls_and_responses: List[Tuple[str, Any]],
    expected_activation_times: int,
    is_websocket: bool,
) -> None:
    """Exercise ``app`` inside a ``TestClient`` session and check both factories.

    The counters are checked at three points:

    * before the client is entered: neither factory has been activated or
      deactivated;
    * while the client is open (right after entering it and after every
      request): the overridden ``dependency_factory`` is still untouched,
      ``override_dependency_factory`` has been activated exactly
      ``expected_activation_times`` times, and nothing has been deactivated;
    * after the client is closed: activation counts are unchanged and, for
      every dependency style other than ``SYNC_FUNCTION``/``ASYNC_FUNCTION``,
      the override has been deactivated once per activation.

    Args:
        app: Application under test.
        dependency_factory: Factory of the dependency that was overridden.
        override_dependency_factory: Factory of the overriding dependency.
        urls_and_responses: ``(url, expected_response)`` pairs, requested in
            order.
        expected_activation_times: Expected activation count of the override.
        is_websocket: Use ``use_websocket`` instead of ``use_endpoint`` to
            reach each URL.
    """

    def assert_counters_while_alive(override_activations: int) -> None:
        # Same four checks, in the same order, at every pre-shutdown checkpoint.
        assert dependency_factory.activation_times == 0
        assert dependency_factory.deactivation_times == 0
        assert override_dependency_factory.activation_times == override_activations
        assert override_dependency_factory.deactivation_times == 0

    send = use_websocket if is_websocket else use_endpoint

    assert_counters_while_alive(0)

    with TestClient(app) as client:
        assert_counters_while_alive(expected_activation_times)

        for url, expected_response in urls_and_responses:
            response = send(client, url)
            assert response == expected_response
            assert_counters_while_alive(expected_activation_times)

    assert dependency_factory.activation_times == 0
    assert override_dependency_factory.activation_times == expected_activation_times

    # Deactivation is only checked for the remaining styles — presumably
    # because plain functions have no teardown step to count (confirm in
    # testing_utilities).
    plain_function_styles = (
        DependencyStyle.SYNC_FUNCTION,
        DependencyStyle.ASYNC_FUNCTION,
    )
    if dependency_factory.dependency_style in plain_function_styles:
        return

    assert dependency_factory.deactivation_times == 0
    assert override_dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
    dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
    """An overridden lifespan-scoped endpoint dependency is activated once.

    The endpoint is registered either on the app itself or on a router that
    is then included. Two requests to ``/test`` must both answer ``11`` while
    only the override factory is activated, exactly once.
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)

    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()

    lifespan_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    create_endpoint_1_annotation(
        router=router,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[None, lifespan_dependency],
        # Presumably the value the helper expects to be injected (offset 10 +
        # first activation) — confirm in testing_utilities.
        expected_value=11,
    )

    if routing_style == "router_endpoint":
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    expect_correct_amount_of_dependency_activations(
        app=app,
        is_websocket=is_websocket,
        urls_and_responses=[("/test", 11)] * 2,
        expected_activation_times=1,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
    dependency_duplication,
    is_websocket,
):
    """Overridden lifespan-scoped dependencies listed in ``dependencies=[...]``.

    The same ``Depends`` object is listed ``dependency_duplication`` times on
    either the ``FastAPI`` app or an ``APIRouter``. With caching the override
    is expected to be activated once; without caching, once per listed copy.
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)

    lifespan_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )
    repeated_dependencies = [lifespan_dependency] * dependency_duplication

    declared_on_app = routing_style == "app"
    app = FastAPI(dependencies=repeated_dependencies) if declared_on_app else FastAPI()
    router = app if declared_on_app else APIRouter(dependencies=repeated_dependencies)

    create_endpoint_0_annotations(
        router=router, path="/test", is_websocket=is_websocket
    )
    if not declared_on_app:
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    expected_activations = 1 if use_cache else dependency_duplication
    expect_correct_amount_of_dependency_activations(
        app=app,
        is_websocket=is_websocket,
        urls_and_responses=[("/test", None)] * 2,
        expected_activation_times=expected_activations,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
    dependency_style: DependencyStyle,
    routing_style,
    use_cache,
    main_dependency_scope: Literal["endpoint", "lifespan"],
    is_websocket,
):
    """Caching of an overridden lifespan dependency requested twice by one parent.

    ``dependency`` takes the same lifespan-scoped sub-dependency for both of
    its parameters. With ``use_cache`` both parameters must receive the same
    value from a single activation of the override (``[11, 11]``); without it
    each parameter gets its own activation (``[11, 12]``).
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)
    lifespan_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    async def dependency(
        sub_dependency1: Annotated[int, lifespan_dependency],
        sub_dependency2: Annotated[int, lifespan_dependency],
    ) -> List[int]:
        return [sub_dependency1, sub_dependency2]

    parent_dependency = Depends(
        dependency,
        use_cache=use_cache,
        dependency_scope=main_dependency_scope,
    )
    create_endpoint_1_annotation(
        router=router,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[List[int], parent_dependency],
    )
    if routing_style == "router":
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    if use_cache:
        expected_response, expected_activations = [11, 11], 1
    else:
        expected_response, expected_activations = [11, 12], 2

    expect_correct_amount_of_dependency_activations(
        app=app,
        urls_and_responses=[("/test", expected_response)] * 2,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
        expected_activation_times=expected_activations,
        is_websocket=is_websocket,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
    dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
    """Caching of an overridden lifespan dependency used three times by one endpoint.

    The endpoint receives the dependency twice directly and once through
    ``endpoint_dependency``. With ``use_cache`` all three values come from a
    single activation of the override (``[11, 11, 11]``); without it three
    activations are expected (``[11, 12, 13]``).
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)
    lifespan_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    async def endpoint_dependency(
        dependency3: Annotated[int, lifespan_dependency],
    ) -> int:
        return dependency3

    create_endpoint_3_annotations(
        router=router,
        path="/test1",
        is_websocket=is_websocket,
        annotation1=Annotated[int, lifespan_dependency],
        annotation2=Annotated[int, lifespan_dependency],
        annotation3=Annotated[int, Depends(endpoint_dependency)],
    )
    if routing_style == "router":
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    if use_cache:
        expected_response, expected_activations = [11, 11, 11], 1
    else:
        expected_response, expected_activations = [11, 12, 13], 3

    expect_correct_amount_of_dependency_activations(
        app=app,
        urls_and_responses=[("/test1", expected_response)] * 2,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
        expected_activation_times=expected_activations,
        is_websocket=is_websocket,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
    dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
    """Caching of an overridden lifespan dependency shared by two endpoints.

    ``/test1`` and ``/test2`` are declared identically: two direct uses of the
    dependency plus one through ``endpoint_dependency``. With ``use_cache``
    one activation of the override serves everything. Without it five
    activations are expected: two per endpoint for the direct parameters,
    while the value arriving through ``endpoint_dependency`` is the same
    (``13``) on both endpoints.
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)
    lifespan_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    async def endpoint_dependency(
        dependency3: Annotated[int, lifespan_dependency],
    ) -> int:
        return dependency3

    paths = ("/test1", "/test2")
    for path in paths:
        create_endpoint_3_annotations(
            router=router,
            path=path,
            is_websocket=is_websocket,
            annotation1=Annotated[int, lifespan_dependency],
            annotation2=Annotated[int, lifespan_dependency],
            annotation3=Annotated[int, Depends(endpoint_dependency)],
        )
    if routing_style == "router":
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    if use_cache:
        response_by_path = {"/test1": [11, 11, 11], "/test2": [11, 11, 11]}
        expected_activations = 1
    else:
        response_by_path = {"/test1": [11, 12, 13], "/test2": [14, 15, 13]}
        expected_activations = 5

    # Visit /test1, /test2 and then both again: repeated requests must see
    # the very same values.
    urls_and_responses = [(path, response_by_path[path]) for path in paths] * 2

    expect_correct_amount_of_dependency_activations(
        app=app,
        urls_and_responses=urls_and_responses,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
        expected_activation_times=expected_activations,
        is_websocket=is_websocket,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
    dependency_style: DependencyStyle, routing_style, is_websocket
):
    """A single ``use_cache=False`` lifespan dependency is still activated once.

    Two requests to ``/test`` must both answer ``11`` with exactly one
    activation of the override, i.e. no new activation per request.
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)
    uncached_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=False,
    )

    app = FastAPI()
    router = app if routing_style == "app" else APIRouter()

    create_endpoint_1_annotation(
        router=router,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[int, uncached_dependency],
    )
    if routing_style == "router":
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    expect_correct_amount_of_dependency_activations(
        app=app,
        is_websocket=is_websocket,
        urls_and_responses=[("/test", 11)] * 2,
        expected_activation_times=1,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize(
    "annotation",
    [
        Annotated[str, Path()],
        Annotated[str, Body()],
        Annotated[str, Query()],
        Annotated[str, Header()],
        SecurityScopes,
        Annotated[str, Cookie()],
        Annotated[str, Form()],
        Annotated[str, File()],
        BackgroundTasks,
        Request,
        WebSocket,
    ],
)
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
    annotation, is_websocket
):
    """An override of a lifespan dependency may not take endpoint-scoped params.

    For each parametrized ``annotation`` the override declares one parameter
    of that kind. Entering ``TestClient(app)`` must then raise
    ``DependencyScopeConflict``.
    """

    # Both functions are async generators, so they are annotated as such
    # (matching the other generator dependencies in this module) rather than
    # as returning ``None``.
    async def dependency_func() -> AsyncGenerator[None, None]:
        yield  # pragma: nocover

    async def override_dependency_func(
        param: annotation,
    ) -> AsyncGenerator[None, None]:
        yield  # pragma: nocover

    app = FastAPI()
    app.dependency_overrides[dependency_func] = override_dependency_func

    create_endpoint_1_annotation(
        router=app,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[
            None, Depends(dependency_func, dependency_scope="lifespan")
        ],
    )

    with pytest.raises(DependencyScopeConflict):
        with TestClient(app):
            pass
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies(
    dependency_style: DependencyStyle, is_websocket
):
    """A lifespan dependency may depend on an overridden lifespan dependency.

    ``lifespan_scoped_dependency`` is not overridden itself but its
    sub-dependency is. The endpoint must receive the override's value
    (``11``) on both requests, from a single activation of the override.
    """
    original_factory = DependencyFactory(dependency_style)
    replacement_factory = DependencyFactory(dependency_style, value_offset=10)

    async def lifespan_scoped_dependency(
        param: Annotated[
            int,
            Depends(original_factory.get_dependency(), dependency_scope="lifespan"),
        ],
    ) -> AsyncGenerator[int, None]:
        yield param

    outer_dependency = Depends(lifespan_scoped_dependency, dependency_scope="lifespan")

    app = FastAPI()
    create_endpoint_1_annotation(
        router=app,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[int, outer_dependency],
    )

    app.dependency_overrides[original_factory.get_dependency()] = (
        replacement_factory.get_dependency()
    )

    expect_correct_amount_of_dependency_activations(
        app=app,
        is_websocket=is_websocket,
        urls_and_responses=[("/test", 11)] * 2,
        expected_activation_times=1,
        dependency_factory=original_factory,
        override_dependency_factory=replacement_factory,
    )
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
    depends_class, is_websocket
):
    """An override of a lifespan dependency may not use a default-scope dependency.

    The override depends on ``sub_dependency`` through ``depends_class``
    (``Depends`` or ``Security``) without passing ``dependency_scope``.
    Entering ``TestClient(app)`` must then raise ``DependencyScopeConflict``.
    """

    async def sub_dependency() -> None:
        pass  # pragma: nocover

    # Both functions below are async generators, so they are annotated as
    # such (matching the other generator dependencies in this module) rather
    # than as returning ``None``.
    async def dependency_func() -> AsyncGenerator[None, None]:
        yield  # pragma: nocover

    async def override_dependency_func(
        param: Annotated[None, depends_class(sub_dependency)],
    ) -> AsyncGenerator[None, None]:
        yield  # pragma: nocover

    app = FastAPI()

    create_endpoint_1_annotation(
        router=app,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[
            None, Depends(dependency_func, dependency_scope="lifespan")
        ],
    )

    app.dependency_overrides[dependency_func] = override_dependency_func

    with pytest.raises(DependencyScopeConflict):
        with TestClient(app):
            pass
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_override_lifespan_scoped_dependencies(
    use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
):
    """An erroring override of a lifespan dependency surfaces when the app starts.

    The override factory is built with ``should_error=True``. Entering
    ``TestClient(app)`` must raise ``IntentionallyBadDependency`` whose
    ``args`` equal ``(1,)``.
    """
    original_factory = DependencyFactory(dependency_style)
    failing_factory = DependencyFactory(dependency_style, should_error=True)
    lifespan_dependency = Depends(
        original_factory.get_dependency(),
        dependency_scope="lifespan",
        use_cache=use_cache,
    )

    app = FastAPI()
    router = app if routing_style == "app_endpoint" else APIRouter()

    create_endpoint_1_annotation(
        router=router,
        path="/test",
        is_websocket=is_websocket,
        annotation=Annotated[int, lifespan_dependency],
    )
    if routing_style == "router_endpoint":
        app.include_router(router)

    app.dependency_overrides[original_factory.get_dependency()] = (
        failing_factory.get_dependency()
    )

    with pytest.raises(IntentionallyBadDependency) as exception_info, TestClient(app):
        pass

    assert exception_info.value.args == (1,)