11 changed files with 1834 additions and 794 deletions
@ -1,703 +0,0 @@ |
|||||
from enum import StrEnum, auto |
|
||||
from typing import Any, AsyncGenerator, List, Tuple, TypeVar |
|
||||
|
|
||||
import pytest |
|
||||
from fastapi import ( |
|
||||
APIRouter, |
|
||||
BackgroundTasks, |
|
||||
Body, |
|
||||
Cookie, |
|
||||
Depends, |
|
||||
FastAPI, |
|
||||
File, |
|
||||
Form, |
|
||||
Header, |
|
||||
Path, |
|
||||
Query, |
|
||||
) |
|
||||
from fastapi.exceptions import FastAPIError |
|
||||
from fastapi.params import Security |
|
||||
from fastapi.security import SecurityScopes |
|
||||
from starlette.testclient import TestClient |
|
||||
from typing_extensions import Annotated, Generator, Literal, assert_never |
|
||||
|
|
||||
T = TypeVar('T') |
|
||||
|
|
||||
|
|
||||
class DependencyStyle(StrEnum): |
|
||||
SYNC_FUNCTION = auto() |
|
||||
ASYNC_FUNCTION = auto() |
|
||||
SYNC_GENERATOR = auto() |
|
||||
ASYNC_GENERATOR = auto() |
|
||||
|
|
||||
|
|
||||
class DependencyFactory: |
|
||||
def __init__( |
|
||||
self, |
|
||||
dependency_style: DependencyStyle, *, |
|
||||
should_error: bool = False |
|
||||
): |
|
||||
self.activation_times = 0 |
|
||||
self.deactivation_times = 0 |
|
||||
self.dependency_style = dependency_style |
|
||||
self._should_error = should_error |
|
||||
|
|
||||
def get_dependency(self): |
|
||||
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
|
||||
return self._synchronous_function_dependency |
|
||||
|
|
||||
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
|
||||
return self._synchronous_generator_dependency |
|
||||
|
|
||||
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
|
||||
return self._asynchronous_function_dependency |
|
||||
|
|
||||
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
|
||||
return self._asynchronous_generator_dependency |
|
||||
|
|
||||
assert_never(self.dependency_style) |
|
||||
|
|
||||
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
|
||||
self.activation_times += 1 |
|
||||
if self._should_error: |
|
||||
raise ValueError(self.activation_times) |
|
||||
|
|
||||
yield self.activation_times |
|
||||
self.deactivation_times += 1 |
|
||||
|
|
||||
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
|
||||
self.activation_times += 1 |
|
||||
if self._should_error: |
|
||||
raise ValueError(self.activation_times) |
|
||||
|
|
||||
yield self.activation_times |
|
||||
self.deactivation_times += 1 |
|
||||
|
|
||||
async def _asynchronous_function_dependency(self) -> T: |
|
||||
self.activation_times += 1 |
|
||||
if self._should_error: |
|
||||
raise ValueError(self.activation_times) |
|
||||
|
|
||||
return self.activation_times |
|
||||
|
|
||||
def _synchronous_function_dependency(self) -> T: |
|
||||
self.activation_times += 1 |
|
||||
if self._should_error: |
|
||||
raise ValueError(self.activation_times) |
|
||||
|
|
||||
return self.activation_times |
|
||||
|
|
||||
|
|
||||
def _expect_correct_amount_of_dependency_activations( |
|
||||
*, |
|
||||
app: FastAPI, |
|
||||
dependency_factory: DependencyFactory, |
|
||||
urls_and_responses: List[Tuple[str, Any]], |
|
||||
expected_activation_times: int |
|
||||
) -> None: |
|
||||
assert dependency_factory.activation_times == 0 |
|
||||
assert dependency_factory.deactivation_times == 0 |
|
||||
with TestClient(app) as client: |
|
||||
assert dependency_factory.activation_times == expected_activation_times |
|
||||
assert dependency_factory.deactivation_times == 0 |
|
||||
|
|
||||
for url, expected_response in urls_and_responses: |
|
||||
response = client.post(url) |
|
||||
response.raise_for_status() |
|
||||
assert response.json() == expected_response |
|
||||
|
|
||||
assert dependency_factory.activation_times == expected_activation_times |
|
||||
assert dependency_factory.deactivation_times == 0 |
|
||||
|
|
||||
assert dependency_factory.activation_times == expected_activation_times |
|
||||
if dependency_factory.dependency_style not in ( |
|
||||
DependencyStyle.SYNC_FUNCTION, |
|
||||
DependencyStyle.ASYNC_FUNCTION |
|
||||
): |
|
||||
assert dependency_factory.deactivation_times == expected_activation_times |
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|
||||
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app_endpoint": |
|
||||
router = app |
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[None, Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache, |
|
||||
)] |
|
||||
) -> None: |
|
||||
assert dependency == 1 |
|
||||
return dependency |
|
||||
|
|
||||
if routing_style == "router_endpoint": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
dependency_factory=dependency_factory, |
|
||||
urls_and_responses=[("/test", 1)] * 2, |
|
||||
expected_activation_times=1 |
|
||||
) |
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|
||||
def test_router_dependencies( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache |
|
||||
) |
|
||||
|
|
||||
if routing_style == "app": |
|
||||
app = FastAPI(dependencies=[depends]) |
|
||||
|
|
||||
@app.post("/test") |
|
||||
async def endpoint() -> None: |
|
||||
return None |
|
||||
else: |
|
||||
app = FastAPI() |
|
||||
router = APIRouter(dependencies=[depends]) |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint() -> None: |
|
||||
return None |
|
||||
|
|
||||
app.include_router(router) |
|
||||
|
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
dependency_factory=dependency_factory, |
|
||||
urls_and_responses=[("/test", None)] * 2, |
|
||||
expected_activation_times=1 |
|
||||
) |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|
||||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|
||||
def test_dependency_cache_in_same_dependency( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache, |
|
||||
main_dependency_scope: Literal["endpoint", "lifespan"] |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache |
|
||||
) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app": |
|
||||
router = app |
|
||||
|
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
async def dependency( |
|
||||
sub_dependency1: Annotated[int, depends], |
|
||||
sub_dependency2: Annotated[int, depends], |
|
||||
) -> List[int]: |
|
||||
return [sub_dependency1, sub_dependency2] |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[List[int], Depends( |
|
||||
dependency, |
|
||||
use_cache=use_cache, |
|
||||
dependency_scope=main_dependency_scope, |
|
||||
)] |
|
||||
) -> List[int]: |
|
||||
return dependency |
|
||||
|
|
||||
if routing_style == "router": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
if use_cache: |
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
urls_and_responses=[ |
|
||||
("/test", [1, 1]), |
|
||||
("/test", [1, 1]), |
|
||||
], |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=1 |
|
||||
) |
|
||||
else: |
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
urls_and_responses=[ |
|
||||
("/test", [1, 2]), |
|
||||
("/test", [1, 2]), |
|
||||
], |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=2 |
|
||||
) |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|
||||
def test_dependency_cache_in_same_endpoint( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache |
|
||||
) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app": |
|
||||
router = app |
|
||||
|
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|
||||
return dependency3 |
|
||||
|
|
||||
@router.post("/test1") |
|
||||
async def endpoint( |
|
||||
dependency1: Annotated[int, depends], |
|
||||
dependency2: Annotated[int, depends], |
|
||||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|
||||
) -> List[int]: |
|
||||
return [dependency1, dependency2, dependency3] |
|
||||
|
|
||||
if routing_style == "router": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
if use_cache: |
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
urls_and_responses=[ |
|
||||
("/test1", [1, 1, 1]), |
|
||||
("/test1", [1, 1, 1]), |
|
||||
], |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=1 |
|
||||
) |
|
||||
else: |
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
urls_and_responses=[ |
|
||||
("/test1", [1, 2, 3]), |
|
||||
("/test1", [1, 2, 3]), |
|
||||
], |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=3 |
|
||||
) |
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|
||||
def test_dependency_cache_in_different_endpoints( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache |
|
||||
) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app": |
|
||||
router = app |
|
||||
|
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|
||||
return dependency3 |
|
||||
|
|
||||
@router.post("/test1") |
|
||||
async def endpoint( |
|
||||
dependency1: Annotated[int, depends], |
|
||||
dependency2: Annotated[int, depends], |
|
||||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|
||||
) -> List[int]: |
|
||||
return [dependency1, dependency2, dependency3] |
|
||||
|
|
||||
@router.post("/test2") |
|
||||
async def endpoint2( |
|
||||
dependency1: Annotated[int, depends], |
|
||||
dependency2: Annotated[int, depends], |
|
||||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|
||||
) -> List[int]: |
|
||||
return [dependency1, dependency2, dependency3] |
|
||||
|
|
||||
if routing_style == "router": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
if use_cache: |
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
urls_and_responses=[ |
|
||||
("/test1", [1, 1, 1]), |
|
||||
("/test2", [1, 1, 1]), |
|
||||
("/test1", [1, 1, 1]), |
|
||||
("/test2", [1, 1, 1]), |
|
||||
], |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=1 |
|
||||
) |
|
||||
else: |
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
urls_and_responses=[ |
|
||||
("/test1", [1, 2, 3]), |
|
||||
("/test2", [4, 5, 3]), |
|
||||
("/test1", [1, 2, 3]), |
|
||||
("/test2", [4, 5, 3]), |
|
||||
], |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=5 |
|
||||
) |
|
||||
|
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|
||||
def test_no_cached_dependency( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=False |
|
||||
) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app": |
|
||||
router = app |
|
||||
|
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[int, depends], |
|
||||
) -> int: |
|
||||
return dependency |
|
||||
|
|
||||
if routing_style == "router": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
dependency_factory=dependency_factory, |
|
||||
urls_and_responses=[("/test", 1)] * 2, |
|
||||
expected_activation_times=1 |
|
||||
) |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("annotation", [ |
|
||||
Annotated[str, Path()], |
|
||||
Annotated[str, Body()], |
|
||||
Annotated[str, Query()], |
|
||||
Annotated[str, Header()], |
|
||||
SecurityScopes, |
|
||||
Annotated[str, Cookie()], |
|
||||
Annotated[str, Form()], |
|
||||
Annotated[str, File()], |
|
||||
BackgroundTasks, |
|
||||
]) |
|
||||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|
||||
annotation |
|
||||
): |
|
||||
async def dependency_func(param: annotation) -> None: |
|
||||
yield |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
with pytest.raises(FastAPIError): |
|
||||
@app.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[ |
|
||||
None, Depends(dependency_func, dependency_scope="lifespan")] |
|
||||
) -> None: |
|
||||
return |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
|
||||
dependency_style: DependencyStyle |
|
||||
): |
|
||||
dependency_factory = DependencyFactory(dependency_style) |
|
||||
|
|
||||
async def lifespan_scoped_dependency( |
|
||||
param: Annotated[int, Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan" |
|
||||
)] |
|
||||
) -> AsyncGenerator[int, None]: |
|
||||
yield param |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
@app.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[int, Depends( |
|
||||
lifespan_scoped_dependency, |
|
||||
dependency_scope="lifespan" |
|
||||
)] |
|
||||
) -> int: |
|
||||
return dependency |
|
||||
|
|
||||
_expect_correct_amount_of_dependency_activations( |
|
||||
app=app, |
|
||||
dependency_factory=dependency_factory, |
|
||||
expected_activation_times=1, |
|
||||
urls_and_responses=[("/test", 1)] * 2 |
|
||||
) |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|
||||
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ |
|
||||
"websocket", "endpoint" |
|
||||
]) |
|
||||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|
||||
depends_class, |
|
||||
route_type |
|
||||
): |
|
||||
async def sub_dependency() -> None: |
|
||||
pass |
|
||||
|
|
||||
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
|
||||
yield |
|
||||
|
|
||||
app = FastAPI() |
|
||||
route_decorator = route_type(app, "/test") |
|
||||
|
|
||||
with pytest.raises(FastAPIError): |
|
||||
@route_decorator |
|
||||
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] |
|
||||
) -> None: |
|
||||
return |
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|
||||
def test_dependencies_must_provide_correct_dependency_scope( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app_endpoint": |
|
||||
router = app |
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
with pytest.raises(FastAPIError): |
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[None, Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="incorrect", |
|
||||
use_cache=use_cache, |
|
||||
)] |
|
||||
) -> None: |
|
||||
assert dependency == 1 |
|
||||
return dependency |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|
||||
def test_endpoints_report_incorrect_dependency_scope( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app_endpoint": |
|
||||
router = app |
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache, |
|
||||
) |
|
||||
# We intentionally change the dependency scope here to bypass the |
|
||||
# validation at the function level. |
|
||||
depends.dependency_scope = "asdad" |
|
||||
|
|
||||
with pytest.raises(FastAPIError): |
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[int, depends] |
|
||||
) -> int: |
|
||||
assert dependency == 1 |
|
||||
return dependency |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|
||||
def test_endpoints_report_uninitialized_dependency( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app_endpoint": |
|
||||
router = app |
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache, |
|
||||
) |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[int, depends] |
|
||||
) -> int: |
|
||||
assert dependency == 1 |
|
||||
return dependency |
|
||||
|
|
||||
if routing_style == "router_endpoint": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
with TestClient(app) as client: |
|
||||
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
|
||||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
|
||||
|
|
||||
try: |
|
||||
with pytest.raises(FastAPIError): |
|
||||
client.post("/test") |
|
||||
finally: |
|
||||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|
||||
def test_endpoints_report_uninitialized_internal_lifespan( |
|
||||
dependency_style: DependencyStyle, |
|
||||
routing_style, |
|
||||
use_cache |
|
||||
): |
|
||||
dependency_factory= DependencyFactory(dependency_style) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app_endpoint": |
|
||||
router = app |
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache, |
|
||||
) |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[int, depends] |
|
||||
) -> int: |
|
||||
assert dependency == 1 |
|
||||
return dependency |
|
||||
|
|
||||
if routing_style == "router_endpoint": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
with TestClient(app) as client: |
|
||||
internal_state = client.app_state["__fastapi__"] |
|
||||
del client.app_state["__fastapi__"] |
|
||||
|
|
||||
try: |
|
||||
with pytest.raises(FastAPIError): |
|
||||
client.post("/test") |
|
||||
finally: |
|
||||
client.app_state["__fastapi__"] = internal_state |
|
||||
|
|
||||
|
|
||||
@pytest.mark.parametrize("use_cache", [True, False]) |
|
||||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|
||||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|
||||
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): |
|
||||
dependency_factory= DependencyFactory(dependency_style, should_error=True) |
|
||||
depends = Depends( |
|
||||
dependency_factory.get_dependency(), |
|
||||
dependency_scope="lifespan", |
|
||||
use_cache=use_cache, |
|
||||
) |
|
||||
|
|
||||
app = FastAPI() |
|
||||
|
|
||||
if routing_style == "app_endpoint": |
|
||||
router = app |
|
||||
|
|
||||
else: |
|
||||
router = APIRouter() |
|
||||
|
|
||||
@router.post("/test") |
|
||||
async def endpoint( |
|
||||
dependency: Annotated[int, depends] |
|
||||
) -> int: |
|
||||
assert dependency == 1 |
|
||||
return dependency |
|
||||
|
|
||||
if routing_style == "router_endpoint": |
|
||||
app.include_router(router) |
|
||||
|
|
||||
with pytest.raises(ValueError) as exception_info: |
|
||||
with TestClient(app): |
|
||||
pass |
|
||||
|
|
||||
assert exception_info.value.args == (1,) |
|
||||
|
|
||||
|
|
||||
# TODO: Add tests for dependency_overrides |
|
||||
# TODO: Add a websocket equivalent to all tests |
|
@ -0,0 +1,634 @@ |
|||||
|
from typing import Any, AsyncGenerator, List, Tuple |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import ( |
||||
|
APIRouter, |
||||
|
BackgroundTasks, |
||||
|
Body, |
||||
|
Cookie, |
||||
|
Depends, |
||||
|
FastAPI, |
||||
|
File, |
||||
|
Form, |
||||
|
Header, |
||||
|
Path, |
||||
|
Query, |
||||
|
) |
||||
|
from fastapi.exceptions import DependencyScopeConflict |
||||
|
from fastapi.params import Security |
||||
|
from fastapi.security import SecurityScopes |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated, Literal |
||||
|
|
||||
|
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
||||
|
DependencyFactory, |
||||
|
DependencyStyle, |
||||
|
IntentionallyBadDependency, |
||||
|
create_endpoint_0_annotations, |
||||
|
create_endpoint_1_annotation, |
||||
|
create_endpoint_3_annotations, |
||||
|
use_endpoint, |
||||
|
use_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def expect_correct_amount_of_dependency_activations( |
||||
|
*, |
||||
|
app: FastAPI, |
||||
|
dependency_factory: DependencyFactory, |
||||
|
override_dependency_factory: DependencyFactory, |
||||
|
urls_and_responses: List[Tuple[str, Any]], |
||||
|
expected_activation_times: int, |
||||
|
is_websocket: bool |
||||
|
) -> None: |
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == 0 |
||||
|
assert override_dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
with TestClient(app) as client: |
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == expected_activation_times |
||||
|
assert override_dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
for url, expected_response in urls_and_responses: |
||||
|
if is_websocket: |
||||
|
response = use_websocket(client, url) |
||||
|
else: |
||||
|
response = use_endpoint(client, url) |
||||
|
|
||||
|
assert response == expected_response |
||||
|
|
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == expected_activation_times |
||||
|
assert override_dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == expected_activation_times |
||||
|
if dependency_factory.dependency_style not in ( |
||||
|
DependencyStyle.SYNC_FUNCTION, |
||||
|
DependencyStyle.ASYNC_FUNCTION |
||||
|
): |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert override_dependency_factory.deactivation_times == expected_activation_times |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoint_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, |
||||
|
Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
], |
||||
|
expected_value=11 |
||||
|
) |
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
urls_and_responses=[("/test", 11)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_router_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
dependency_duplication, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
app = FastAPI(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
app = FastAPI() |
||||
|
router = APIRouter(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
urls_and_responses=[("/test", None)] * 2, |
||||
|
expected_activation_times=1 if use_cache else dependency_duplication, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
||||
|
def test_dependency_cache_in_same_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
main_dependency_scope: Literal["endpoint", "lifespan"], |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def dependency( |
||||
|
sub_dependency1: Annotated[int, depends], |
||||
|
sub_dependency2: Annotated[int, depends], |
||||
|
) -> List[int]: |
||||
|
return [sub_dependency1, sub_dependency2] |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[List[int], Depends( |
||||
|
dependency, |
||||
|
use_cache=use_cache, |
||||
|
dependency_scope=main_dependency_scope, |
||||
|
)] |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[ |
||||
|
dependency_factory.get_dependency() |
||||
|
] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [11, 11]), |
||||
|
("/test", [11, 11]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [11, 12]), |
||||
|
("/test", [11, 12]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=2, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_same_endpoint( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test1", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[ |
||||
|
dependency_factory.get_dependency() |
||||
|
] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 11, 11]), |
||||
|
("/test1", [11, 11, 11]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 12, 13]), |
||||
|
("/test1", [11, 12, 13]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=3, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_different_endpoints( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test1", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test2", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[ |
||||
|
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 11, 11]), |
||||
|
("/test2", [11, 11, 11]), |
||||
|
("/test1", [11, 11, 11]), |
||||
|
("/test2", [11, 11, 11]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 12, 13]), |
||||
|
("/test2", [14, 15, 13]), |
||||
|
("/test1", [11, 12, 13]), |
||||
|
("/test2", [14, 15, 13]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=5, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_no_cached_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=False |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[ |
||||
|
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
urls_and_responses=[("/test", 11)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("annotation", [ |
||||
|
Annotated[str, Path()], |
||||
|
Annotated[str, Body()], |
||||
|
Annotated[str, Query()], |
||||
|
Annotated[str, Header()], |
||||
|
SecurityScopes, |
||||
|
Annotated[str, Cookie()], |
||||
|
Annotated[str, Form()], |
||||
|
Annotated[str, File()], |
||||
|
BackgroundTasks, |
||||
|
]) |
||||
|
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
||||
|
annotation, |
||||
|
is_websocket |
||||
|
): |
||||
|
async def dependency_func() -> None: |
||||
|
yield |
||||
|
|
||||
|
async def override_dependency_func(param: annotation) -> None: |
||||
|
yield |
||||
|
|
||||
|
app = FastAPI() |
||||
|
app.dependency_overrides[dependency_func] = override_dependency_func |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[None, |
||||
|
Depends(dependency_func, dependency_scope="lifespan") |
||||
|
] |
||||
|
) |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory( |
||||
|
dependency_style, |
||||
|
value_offset=10 |
||||
|
) |
||||
|
|
||||
|
async def lifespan_scoped_dependency( |
||||
|
param: Annotated[int, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan" |
||||
|
)] |
||||
|
) -> AsyncGenerator[int, None]: |
||||
|
yield param |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
int, |
||||
|
Depends(lifespan_scoped_dependency, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
urls_and_responses=[("/test", 11)] * 2, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
||||
|
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
||||
|
depends_class, |
||||
|
is_websocket |
||||
|
): |
||||
|
async def sub_dependency() -> None: |
||||
|
pass |
||||
|
|
||||
|
async def dependency_func() -> None: |
||||
|
yield |
||||
|
|
||||
|
async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
||||
|
yield |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] |
||||
|
) |
||||
|
|
||||
|
app.dependency_overrides[dependency_func] = override_dependency_func |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_bad_override_lifespan_scoped_dependencies( |
||||
|
use_cache, |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, should_error=True) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends] |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
||||
|
|
||||
|
with pytest.raises(IntentionallyBadDependency) as exception_info: |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
assert exception_info.value.args == (1,) |
@ -0,0 +1,854 @@ |
|||||
|
import warnings |
||||
|
from contextlib import asynccontextmanager |
||||
|
from typing import Any, AsyncGenerator, Dict, List, Tuple |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import ( |
||||
|
APIRouter, |
||||
|
BackgroundTasks, |
||||
|
Body, |
||||
|
Cookie, |
||||
|
Depends, |
||||
|
FastAPI, |
||||
|
File, |
||||
|
Form, |
||||
|
Header, |
||||
|
Path, |
||||
|
Query, |
||||
|
) |
||||
|
from fastapi.exceptions import ( |
||||
|
DependencyScopeConflict, |
||||
|
InvalidDependencyScope, |
||||
|
UninitializedLifespanDependency, |
||||
|
) |
||||
|
from fastapi.params import Security |
||||
|
from fastapi.security import SecurityScopes |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated, Literal, assert_never |
||||
|
|
||||
|
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
||||
|
DependencyFactory, |
||||
|
DependencyStyle, |
||||
|
IntentionallyBadDependency, |
||||
|
create_endpoint_0_annotations, |
||||
|
create_endpoint_1_annotation, |
||||
|
create_endpoint_2_annotations, |
||||
|
create_endpoint_3_annotations, |
||||
|
use_endpoint, |
||||
|
use_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def expect_correct_amount_of_dependency_activations( |
||||
|
*, |
||||
|
app: FastAPI, |
||||
|
dependency_factory: DependencyFactory, |
||||
|
urls_and_responses: List[Tuple[str, Any]], |
||||
|
expected_activation_times: int, |
||||
|
is_websocket: bool |
||||
|
) -> None: |
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
with TestClient(app) as client: |
||||
|
assert dependency_factory.activation_times == expected_activation_times |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
for url, expected_response in urls_and_responses: |
||||
|
if is_websocket: |
||||
|
assert use_websocket(client, url) == expected_response |
||||
|
else: |
||||
|
assert use_endpoint(client, url) == expected_response |
||||
|
|
||||
|
assert dependency_factory.activation_times == expected_activation_times |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert dependency_factory.activation_times == expected_activation_times |
||||
|
if dependency_factory.dependency_style not in ( |
||||
|
DependencyStyle.SYNC_FUNCTION, |
||||
|
DependencyStyle.ASYNC_FUNCTION |
||||
|
): |
||||
|
assert dependency_factory.deactivation_times == expected_activation_times |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoint_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket: bool, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[None, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
)], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_router_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
dependency_duplication, |
||||
|
is_websocket: bool, |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
app = FastAPI(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
app = FastAPI() |
||||
|
router = APIRouter(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
app.include_router(router) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
urls_and_responses=[("/test", None)] * 2, |
||||
|
expected_activation_times=1 if use_cache else dependency_duplication, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
||||
|
def test_dependency_cache_in_same_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
main_dependency_scope: Literal["endpoint", "lifespan"], |
||||
|
is_websocket: bool, |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def dependency( |
||||
|
sub_dependency1: Annotated[int, depends], |
||||
|
sub_dependency2: Annotated[int, depends], |
||||
|
) -> List[int]: |
||||
|
return [sub_dependency1, sub_dependency2] |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[List[int], Depends( |
||||
|
dependency, |
||||
|
use_cache=use_cache, |
||||
|
dependency_scope=main_dependency_scope, |
||||
|
)] |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 1]), |
||||
|
("/test", [1, 1]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 2]), |
||||
|
("/test", [1, 2]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=2, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_same_endpoint( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)] |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 1, 1]), |
||||
|
("/test", [1, 1, 1]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 2, 3]), |
||||
|
("/test", [1, 2, 3]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=3, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_different_endpoints( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test1", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)] |
||||
|
) |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test2", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)] |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [1, 1, 1]), |
||||
|
("/test2", [1, 1, 1]), |
||||
|
("/test1", [1, 1, 1]), |
||||
|
("/test2", [1, 1, 1]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [1, 2, 3]), |
||||
|
("/test2", [4, 5, 3]), |
||||
|
("/test1", [1, 2, 3]), |
||||
|
("/test2", [4, 5, 3]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=5, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_no_cached_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
is_websocket, |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=False |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("annotation", [ |
||||
|
Annotated[str, Path()], |
||||
|
Annotated[str, Body()], |
||||
|
Annotated[str, Query()], |
||||
|
Annotated[str, Header()], |
||||
|
SecurityScopes, |
||||
|
Annotated[str, Cookie()], |
||||
|
Annotated[str, Form()], |
||||
|
Annotated[str, File()], |
||||
|
BackgroundTasks, |
||||
|
]) |
||||
|
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
||||
|
annotation, |
||||
|
is_websocket |
||||
|
): |
||||
|
async def dependency_func(param: annotation) -> None: |
||||
|
yield |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, |
||||
|
Depends(dependency_func, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
async def lifespan_scoped_dependency( |
||||
|
param: Annotated[int, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan" |
||||
|
)] |
||||
|
) -> AsyncGenerator[int, None]: |
||||
|
yield param |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, Depends(lifespan_scoped_dependency)], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize([ |
||||
|
"dependency_style", |
||||
|
"supports_teardown" |
||||
|
], [ |
||||
|
(DependencyStyle.SYNC_FUNCTION, False), |
||||
|
(DependencyStyle.ASYNC_FUNCTION, False), |
||||
|
(DependencyStyle.SYNC_GENERATOR, True), |
||||
|
(DependencyStyle.ASYNC_GENERATOR, True), |
||||
|
]) |
||||
|
def test_the_same_dependency_can_work_in_different_scopes( |
||||
|
dependency_style: DependencyStyle, |
||||
|
supports_teardown, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_2_annotations( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="endpoint" |
||||
|
)], |
||||
|
annotation2=Annotated[int, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan" |
||||
|
)], |
||||
|
) |
||||
|
if is_websocket: |
||||
|
get_response = use_websocket |
||||
|
else: |
||||
|
get_response = use_endpoint |
||||
|
|
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
with TestClient(app) as client: |
||||
|
assert dependency_factory.activation_times == 1 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert get_response(client, "/test") == [2, 1] |
||||
|
assert dependency_factory.activation_times == 2 |
||||
|
if supports_teardown: |
||||
|
assert dependency_factory.deactivation_times == 1 |
||||
|
else: |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert get_response(client, "/test") == [3, 1] |
||||
|
assert dependency_factory.activation_times == 3 |
||||
|
if supports_teardown: |
||||
|
assert dependency_factory.deactivation_times == 2 |
||||
|
else: |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert dependency_factory.activation_times == 3 |
||||
|
if supports_teardown: |
||||
|
assert dependency_factory.deactivation_times == 3 |
||||
|
else: |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) |
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( |
||||
|
dependency_style: DependencyStyle, |
||||
|
is_websocket, |
||||
|
lifespan_style: Literal["lifespan_function", "lifespan_events"] |
||||
|
): |
||||
|
lifespan_started = False |
||||
|
lifespan_ended = False |
||||
|
if lifespan_style == "lifespan_generator": |
||||
|
@asynccontextmanager |
||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: |
||||
|
nonlocal lifespan_started |
||||
|
nonlocal lifespan_ended |
||||
|
lifespan_started = True |
||||
|
yield |
||||
|
lifespan_ended = True |
||||
|
|
||||
|
app = FastAPI(lifespan=lifespan) |
||||
|
elif lifespan_style == "events_decorator": |
||||
|
app = FastAPI() |
||||
|
with warnings.catch_warnings(action="ignore", category=DeprecationWarning): |
||||
|
@app.on_event("startup") |
||||
|
async def startup() -> None: |
||||
|
nonlocal lifespan_started |
||||
|
lifespan_started = True |
||||
|
|
||||
|
@app.on_event("shutdown") |
||||
|
async def shutdown() -> None: |
||||
|
nonlocal lifespan_ended |
||||
|
lifespan_ended = True |
||||
|
elif lifespan_style == "events_constructor": |
||||
|
async def startup() -> None: |
||||
|
nonlocal lifespan_started |
||||
|
lifespan_started = True |
||||
|
|
||||
|
async def shutdown() -> None: |
||||
|
nonlocal lifespan_ended |
||||
|
lifespan_ended = True |
||||
|
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) |
||||
|
else: |
||||
|
assert_never(lifespan_style) |
||||
|
|
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan" |
||||
|
)], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
is_websocket=is_websocket |
||||
|
) |
||||
|
assert lifespan_started and lifespan_ended |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
||||
|
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
||||
|
depends_class, |
||||
|
is_websocket |
||||
|
): |
||||
|
async def sub_dependency() -> None: |
||||
|
pass |
||||
|
|
||||
|
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
||||
|
yield |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], |
||||
|
) |
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_dependencies_must_provide_correct_dependency_scope( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
with pytest.raises( |
||||
|
InvalidDependencyScope, |
||||
|
match=r'Dependency "value" of .* has an invalid scope: ' |
||||
|
r'"incorrect"' |
||||
|
): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[None, Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="incorrect", |
||||
|
use_cache=use_cache, |
||||
|
)] |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoints_report_incorrect_dependency_scope( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
# We intentionally change the dependency scope here to bypass the |
||||
|
# validation at the function level. |
||||
|
depends.dependency_scope = "asdad" |
||||
|
|
||||
|
with pytest.raises(InvalidDependencyScope): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends] |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoints_report_uninitialized_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
with TestClient(app) as client: |
||||
|
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
||||
|
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
||||
|
|
||||
|
try: |
||||
|
with pytest.raises(UninitializedLifespanDependency): |
||||
|
if is_websocket: |
||||
|
with client.websocket_connect("/test"): |
||||
|
pass |
||||
|
else: |
||||
|
client.post("/test") |
||||
|
finally: |
||||
|
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoints_report_uninitialized_internal_lifespan( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
with TestClient(app) as client: |
||||
|
internal_state = client.app_state["__fastapi__"] |
||||
|
del client.app_state["__fastapi__"] |
||||
|
|
||||
|
try: |
||||
|
with pytest.raises(UninitializedLifespanDependency): |
||||
|
if is_websocket: |
||||
|
with client.websocket_connect("/test"): |
||||
|
pass |
||||
|
else: |
||||
|
client.post("/test") |
||||
|
finally: |
||||
|
client.app_state["__fastapi__"] = internal_state |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_bad_lifespan_scoped_dependencies( |
||||
|
use_cache, |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
is_websocket |
||||
|
): |
||||
|
dependency_factory= DependencyFactory(dependency_style, should_error=True) |
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1 |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
with pytest.raises(IntentionallyBadDependency) as exception_info: |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
assert exception_info.value.args == (1,) |
@ -0,0 +1,202 @@ |
|||||
|
from enum import StrEnum, auto |
||||
|
from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never |
||||
|
|
||||
|
from fastapi import APIRouter, FastAPI, WebSocket |
||||
|
from starlette.testclient import TestClient |
||||
|
from starlette.websockets import WebSocketDisconnect |
||||
|
|
||||
|
T = TypeVar('T') |
||||
|
|
||||
|
|
||||
|
class DependencyStyle(StrEnum): |
||||
|
SYNC_FUNCTION = auto() |
||||
|
ASYNC_FUNCTION = auto() |
||||
|
SYNC_GENERATOR = auto() |
||||
|
ASYNC_GENERATOR = auto() |
||||
|
|
||||
|
|
||||
|
class IntentionallyBadDependency(Exception): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class DependencyFactory: |
||||
|
def __init__( |
||||
|
self, |
||||
|
dependency_style: DependencyStyle, *, |
||||
|
should_error: bool = False, |
||||
|
value_offset: int = 0, |
||||
|
): |
||||
|
self.activation_times = 0 |
||||
|
self.deactivation_times = 0 |
||||
|
self.dependency_style = dependency_style |
||||
|
self._should_error = should_error |
||||
|
self._value_offset = value_offset |
||||
|
|
||||
|
def get_dependency(self): |
||||
|
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
||||
|
return self._synchronous_function_dependency |
||||
|
|
||||
|
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
||||
|
return self._synchronous_generator_dependency |
||||
|
|
||||
|
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
||||
|
return self._asynchronous_function_dependency |
||||
|
|
||||
|
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
||||
|
return self._asynchronous_generator_dependency |
||||
|
|
||||
|
assert_never(self.dependency_style) |
||||
|
|
||||
|
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
yield self.activation_times + self._value_offset |
||||
|
self.deactivation_times += 1 |
||||
|
|
||||
|
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
yield self.activation_times + self._value_offset |
||||
|
self.deactivation_times += 1 |
||||
|
|
||||
|
async def _asynchronous_function_dependency(self) -> T: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
return self.activation_times + self._value_offset |
||||
|
|
||||
|
def _synchronous_function_dependency(self) -> T: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
return self.activation_times + self._value_offset |
||||
|
|
||||
|
|
||||
|
def use_endpoint(client: TestClient, url: str) -> Any: |
||||
|
response = client.post(url) |
||||
|
response.raise_for_status() |
||||
|
return response.json() |
||||
|
|
||||
|
|
||||
|
def use_websocket(client: TestClient, url: str) -> Any: |
||||
|
with client.websocket_connect(url) as connection: |
||||
|
return connection.receive_json() |
||||
|
|
||||
|
|
||||
|
def create_endpoint_0_annotations( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
@router.websocket(path) |
||||
|
async def endpoint(websocket: WebSocket) -> None: |
||||
|
await websocket.accept() |
||||
|
try: |
||||
|
await websocket.send_json(None) |
||||
|
except WebSocketDisconnect: |
||||
|
pass |
||||
|
else: |
||||
|
@router.post(path) |
||||
|
async def endpoint() -> None: |
||||
|
return None |
||||
|
|
||||
|
|
||||
|
def create_endpoint_1_annotation( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
annotation: Any, |
||||
|
expected_value: Any = None |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
@router.websocket(path) |
||||
|
async def endpoint( |
||||
|
websocket: WebSocket, |
||||
|
value: annotation |
||||
|
) -> None: |
||||
|
if expected_value is not None: |
||||
|
assert value == expected_value |
||||
|
|
||||
|
await websocket.accept() |
||||
|
try: |
||||
|
await websocket.send_json(value) |
||||
|
except WebSocketDisconnect: |
||||
|
pass |
||||
|
else: |
||||
|
@router.post(path) |
||||
|
async def endpoint( |
||||
|
value: annotation |
||||
|
) -> None: |
||||
|
if expected_value is not None: |
||||
|
assert value == expected_value |
||||
|
|
||||
|
return value |
||||
|
|
||||
|
def create_endpoint_2_annotations( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
annotation1: Any, |
||||
|
annotation2: Any, |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
@router.websocket(path) |
||||
|
async def endpoint( |
||||
|
websocket: WebSocket, |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
) -> None: |
||||
|
await websocket.accept() |
||||
|
try: |
||||
|
await websocket.send_json([value1, value2]) |
||||
|
except WebSocketDisconnect: |
||||
|
await websocket.close() |
||||
|
else: |
||||
|
@router.post(path) |
||||
|
async def endpoint( |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
) -> list[Any]: |
||||
|
return [value1, value2] |
||||
|
|
||||
|
|
||||
|
def create_endpoint_3_annotations( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
annotation1: Any, |
||||
|
annotation2: Any, |
||||
|
annotation3: Any |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
@router.websocket(path) |
||||
|
async def endpoint( |
||||
|
websocket: WebSocket, |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
value3: annotation3 |
||||
|
) -> None: |
||||
|
await websocket.accept() |
||||
|
try: |
||||
|
await websocket.send_json([value1, value2, value3]) |
||||
|
except WebSocketDisconnect: |
||||
|
await websocket.close() |
||||
|
else: |
||||
|
@router.post(path) |
||||
|
async def endpoint( |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
value3: annotation3 |
||||
|
) -> list[Any]: |
||||
|
return [value1, value2, value3] |
Loading…
Reference in new issue