11 changed files with 1834 additions and 794 deletions
@ -1,703 +0,0 @@ |
|||
from enum import StrEnum, auto |
|||
from typing import Any, AsyncGenerator, List, Tuple, TypeVar |
|||
|
|||
import pytest |
|||
from fastapi import ( |
|||
APIRouter, |
|||
BackgroundTasks, |
|||
Body, |
|||
Cookie, |
|||
Depends, |
|||
FastAPI, |
|||
File, |
|||
Form, |
|||
Header, |
|||
Path, |
|||
Query, |
|||
) |
|||
from fastapi.exceptions import FastAPIError |
|||
from fastapi.params import Security |
|||
from fastapi.security import SecurityScopes |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Annotated, Generator, Literal, assert_never |
|||
|
|||
T = TypeVar('T') |
|||
|
|||
|
|||
class DependencyStyle(StrEnum): |
|||
SYNC_FUNCTION = auto() |
|||
ASYNC_FUNCTION = auto() |
|||
SYNC_GENERATOR = auto() |
|||
ASYNC_GENERATOR = auto() |
|||
|
|||
|
|||
class DependencyFactory: |
|||
def __init__( |
|||
self, |
|||
dependency_style: DependencyStyle, *, |
|||
should_error: bool = False |
|||
): |
|||
self.activation_times = 0 |
|||
self.deactivation_times = 0 |
|||
self.dependency_style = dependency_style |
|||
self._should_error = should_error |
|||
|
|||
def get_dependency(self): |
|||
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
|||
return self._synchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
|||
return self._synchronous_generator_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
|||
return self._asynchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
|||
return self._asynchronous_generator_dependency |
|||
|
|||
assert_never(self.dependency_style) |
|||
|
|||
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
yield self.activation_times |
|||
self.deactivation_times += 1 |
|||
|
|||
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
yield self.activation_times |
|||
self.deactivation_times += 1 |
|||
|
|||
async def _asynchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
return self.activation_times |
|||
|
|||
def _synchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
return self.activation_times |
|||
|
|||
|
|||
def _expect_correct_amount_of_dependency_activations( |
|||
*, |
|||
app: FastAPI, |
|||
dependency_factory: DependencyFactory, |
|||
urls_and_responses: List[Tuple[str, Any]], |
|||
expected_activation_times: int |
|||
) -> None: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
for url, expected_response in urls_and_responses: |
|||
response = client.post(url) |
|||
response.raise_for_status() |
|||
assert response.json() == expected_response |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
if dependency_factory.dependency_style not in ( |
|||
DependencyStyle.SYNC_FUNCTION, |
|||
DependencyStyle.ASYNC_FUNCTION |
|||
): |
|||
assert dependency_factory.deactivation_times == expected_activation_times |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[None, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
)] |
|||
) -> None: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1 |
|||
) |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_router_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends]) |
|||
|
|||
@app.post("/test") |
|||
async def endpoint() -> None: |
|||
return None |
|||
else: |
|||
app = FastAPI() |
|||
router = APIRouter(dependencies=[depends]) |
|||
|
|||
@router.post("/test") |
|||
async def endpoint() -> None: |
|||
return None |
|||
|
|||
app.include_router(router) |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", None)] * 2, |
|||
expected_activation_times=1 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|||
def test_dependency_cache_in_same_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
main_dependency_scope: Literal["endpoint", "lifespan"] |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def dependency( |
|||
sub_dependency1: Annotated[int, depends], |
|||
sub_dependency2: Annotated[int, depends], |
|||
) -> List[int]: |
|||
return [sub_dependency1, sub_dependency2] |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[List[int], Depends( |
|||
dependency, |
|||
use_cache=use_cache, |
|||
dependency_scope=main_dependency_scope, |
|||
)] |
|||
) -> List[int]: |
|||
return dependency |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 1]), |
|||
("/test", [1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1 |
|||
) |
|||
else: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 2]), |
|||
("/test", [1, 2]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=2 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_same_endpoint( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
@router.post("/test1") |
|||
async def endpoint( |
|||
dependency1: Annotated[int, depends], |
|||
dependency2: Annotated[int, depends], |
|||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|||
) -> List[int]: |
|||
return [dependency1, dependency2, dependency3] |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 1, 1]), |
|||
("/test1", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1 |
|||
) |
|||
else: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 2, 3]), |
|||
("/test1", [1, 2, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=3 |
|||
) |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_different_endpoints( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
@router.post("/test1") |
|||
async def endpoint( |
|||
dependency1: Annotated[int, depends], |
|||
dependency2: Annotated[int, depends], |
|||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|||
) -> List[int]: |
|||
return [dependency1, dependency2, dependency3] |
|||
|
|||
@router.post("/test2") |
|||
async def endpoint2( |
|||
dependency1: Annotated[int, depends], |
|||
dependency2: Annotated[int, depends], |
|||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|||
) -> List[int]: |
|||
return [dependency1, dependency2, dependency3] |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1 |
|||
) |
|||
else: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=5 |
|||
) |
|||
|
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_no_cached_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=False |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends], |
|||
) -> int: |
|||
return dependency |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("annotation", [ |
|||
Annotated[str, Path()], |
|||
Annotated[str, Body()], |
|||
Annotated[str, Query()], |
|||
Annotated[str, Header()], |
|||
SecurityScopes, |
|||
Annotated[str, Cookie()], |
|||
Annotated[str, Form()], |
|||
Annotated[str, File()], |
|||
BackgroundTasks, |
|||
]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|||
annotation |
|||
): |
|||
async def dependency_func(param: annotation) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@app.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[ |
|||
None, Depends(dependency_func, dependency_scope="lifespan")] |
|||
) -> None: |
|||
return |
|||
|
|||
|
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
|||
dependency_style: DependencyStyle |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
async def lifespan_scoped_dependency( |
|||
param: Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan" |
|||
)] |
|||
) -> AsyncGenerator[int, None]: |
|||
yield param |
|||
|
|||
app = FastAPI() |
|||
|
|||
@app.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, Depends( |
|||
lifespan_scoped_dependency, |
|||
dependency_scope="lifespan" |
|||
)] |
|||
) -> int: |
|||
return dependency |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 1)] * 2 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|||
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ |
|||
"websocket", "endpoint" |
|||
]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|||
depends_class, |
|||
route_type |
|||
): |
|||
async def sub_dependency() -> None: |
|||
pass |
|||
|
|||
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
route_decorator = route_type(app, "/test") |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@route_decorator |
|||
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] |
|||
) -> None: |
|||
return |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_dependencies_must_provide_correct_dependency_scope( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[None, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="incorrect", |
|||
use_cache=use_cache, |
|||
)] |
|||
) -> None: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_incorrect_dependency_scope( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
# We intentionally change the dependency scope here to bypass the |
|||
# validation at the function level. |
|||
depends.dependency_scope = "asdad" |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
|||
|
|||
try: |
|||
with pytest.raises(FastAPIError): |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_internal_lifespan( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
internal_state = client.app_state["__fastapi__"] |
|||
del client.app_state["__fastapi__"] |
|||
|
|||
try: |
|||
with pytest.raises(FastAPIError): |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"] = internal_state |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): |
|||
dependency_factory= DependencyFactory(dependency_style, should_error=True) |
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with pytest.raises(ValueError) as exception_info: |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
assert exception_info.value.args == (1,) |
|||
|
|||
|
|||
# TODO: Add tests for dependency_overrides |
|||
# TODO: Add a websocket equivalent to all tests |
@ -0,0 +1,634 @@ |
|||
from typing import Any, AsyncGenerator, List, Tuple |
|||
|
|||
import pytest |
|||
from fastapi import ( |
|||
APIRouter, |
|||
BackgroundTasks, |
|||
Body, |
|||
Cookie, |
|||
Depends, |
|||
FastAPI, |
|||
File, |
|||
Form, |
|||
Header, |
|||
Path, |
|||
Query, |
|||
) |
|||
from fastapi.exceptions import DependencyScopeConflict |
|||
from fastapi.params import Security |
|||
from fastapi.security import SecurityScopes |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated, Literal |
|||
|
|||
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
|||
DependencyFactory, |
|||
DependencyStyle, |
|||
IntentionallyBadDependency, |
|||
create_endpoint_0_annotations, |
|||
create_endpoint_1_annotation, |
|||
create_endpoint_3_annotations, |
|||
use_endpoint, |
|||
use_websocket, |
|||
) |
|||
|
|||
|
|||
def expect_correct_amount_of_dependency_activations( |
|||
*, |
|||
app: FastAPI, |
|||
dependency_factory: DependencyFactory, |
|||
override_dependency_factory: DependencyFactory, |
|||
urls_and_responses: List[Tuple[str, Any]], |
|||
expected_activation_times: int, |
|||
is_websocket: bool |
|||
) -> None: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert override_dependency_factory.activation_times == 0 |
|||
assert override_dependency_factory.deactivation_times == 0 |
|||
|
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert override_dependency_factory.activation_times == expected_activation_times |
|||
assert override_dependency_factory.deactivation_times == 0 |
|||
|
|||
for url, expected_response in urls_and_responses: |
|||
if is_websocket: |
|||
response = use_websocket(client, url) |
|||
else: |
|||
response = use_endpoint(client, url) |
|||
|
|||
assert response == expected_response |
|||
|
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert override_dependency_factory.activation_times == expected_activation_times |
|||
assert override_dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == 0 |
|||
assert override_dependency_factory.activation_times == expected_activation_times |
|||
if dependency_factory.dependency_style not in ( |
|||
DependencyStyle.SYNC_FUNCTION, |
|||
DependencyStyle.ASYNC_FUNCTION |
|||
): |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert override_dependency_factory.deactivation_times == expected_activation_times |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoint_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, |
|||
Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
], |
|||
expected_value=11 |
|||
) |
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
urls_and_responses=[("/test", 11)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_router_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
dependency_duplication, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
app = FastAPI() |
|||
router = APIRouter(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
urls_and_responses=[("/test", None)] * 2, |
|||
expected_activation_times=1 if use_cache else dependency_duplication, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|||
def test_dependency_cache_in_same_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
main_dependency_scope: Literal["endpoint", "lifespan"], |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def dependency( |
|||
sub_dependency1: Annotated[int, depends], |
|||
sub_dependency2: Annotated[int, depends], |
|||
) -> List[int]: |
|||
return [sub_dependency1, sub_dependency2] |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[List[int], Depends( |
|||
dependency, |
|||
use_cache=use_cache, |
|||
dependency_scope=main_dependency_scope, |
|||
)] |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[ |
|||
dependency_factory.get_dependency() |
|||
] = override_dependency_factory.get_dependency() |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [11, 11]), |
|||
("/test", [11, 11]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [11, 12]), |
|||
("/test", [11, 12]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=2, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_same_endpoint( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test1", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[ |
|||
dependency_factory.get_dependency() |
|||
] = override_dependency_factory.get_dependency() |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 11, 11]), |
|||
("/test1", [11, 11, 11]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 12, 13]), |
|||
("/test1", [11, 12, 13]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=3, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_different_endpoints( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test1", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test2", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[ |
|||
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 11, 11]), |
|||
("/test2", [11, 11, 11]), |
|||
("/test1", [11, 11, 11]), |
|||
("/test2", [11, 11, 11]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 12, 13]), |
|||
("/test2", [14, 15, 13]), |
|||
("/test1", [11, 12, 13]), |
|||
("/test2", [14, 15, 13]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=5, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_no_cached_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=False |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[ |
|||
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
urls_and_responses=[("/test", 11)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("annotation", [ |
|||
Annotated[str, Path()], |
|||
Annotated[str, Body()], |
|||
Annotated[str, Query()], |
|||
Annotated[str, Header()], |
|||
SecurityScopes, |
|||
Annotated[str, Cookie()], |
|||
Annotated[str, Form()], |
|||
Annotated[str, File()], |
|||
BackgroundTasks, |
|||
]) |
|||
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|||
annotation, |
|||
is_websocket |
|||
): |
|||
async def dependency_func() -> None: |
|||
yield |
|||
|
|||
async def override_dependency_func(param: annotation) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
app.dependency_overrides[dependency_func] = override_dependency_func |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[None, |
|||
Depends(dependency_func, dependency_scope="lifespan") |
|||
] |
|||
) |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory( |
|||
dependency_style, |
|||
value_offset=10 |
|||
) |
|||
|
|||
async def lifespan_scoped_dependency( |
|||
param: Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan" |
|||
)] |
|||
) -> AsyncGenerator[int, None]: |
|||
yield param |
|||
|
|||
app = FastAPI() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
int, |
|||
Depends(lifespan_scoped_dependency, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 11)] * 2, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|||
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|||
depends_class, |
|||
is_websocket |
|||
): |
|||
async def sub_dependency() -> None: |
|||
pass |
|||
|
|||
async def dependency_func() -> None: |
|||
yield |
|||
|
|||
async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] |
|||
) |
|||
|
|||
app.dependency_overrides[dependency_func] = override_dependency_func |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_bad_override_lifespan_scoped_dependencies( |
|||
use_cache, |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, should_error=True) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends] |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency() |
|||
|
|||
with pytest.raises(IntentionallyBadDependency) as exception_info: |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
assert exception_info.value.args == (1,) |
@ -0,0 +1,854 @@ |
|||
import warnings |
|||
from contextlib import asynccontextmanager |
|||
from typing import Any, AsyncGenerator, Dict, List, Tuple |
|||
|
|||
import pytest |
|||
from fastapi import ( |
|||
APIRouter, |
|||
BackgroundTasks, |
|||
Body, |
|||
Cookie, |
|||
Depends, |
|||
FastAPI, |
|||
File, |
|||
Form, |
|||
Header, |
|||
Path, |
|||
Query, |
|||
) |
|||
from fastapi.exceptions import ( |
|||
DependencyScopeConflict, |
|||
InvalidDependencyScope, |
|||
UninitializedLifespanDependency, |
|||
) |
|||
from fastapi.params import Security |
|||
from fastapi.security import SecurityScopes |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated, Literal, assert_never |
|||
|
|||
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
|||
DependencyFactory, |
|||
DependencyStyle, |
|||
IntentionallyBadDependency, |
|||
create_endpoint_0_annotations, |
|||
create_endpoint_1_annotation, |
|||
create_endpoint_2_annotations, |
|||
create_endpoint_3_annotations, |
|||
use_endpoint, |
|||
use_websocket, |
|||
) |
|||
|
|||
|
|||
def expect_correct_amount_of_dependency_activations( |
|||
*, |
|||
app: FastAPI, |
|||
dependency_factory: DependencyFactory, |
|||
urls_and_responses: List[Tuple[str, Any]], |
|||
expected_activation_times: int, |
|||
is_websocket: bool |
|||
) -> None: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
for url, expected_response in urls_and_responses: |
|||
if is_websocket: |
|||
assert use_websocket(client, url) == expected_response |
|||
else: |
|||
assert use_endpoint(client, url) == expected_response |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
if dependency_factory.dependency_style not in ( |
|||
DependencyStyle.SYNC_FUNCTION, |
|||
DependencyStyle.ASYNC_FUNCTION |
|||
): |
|||
assert dependency_factory.deactivation_times == expected_activation_times |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoint_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket: bool, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[None, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
)], |
|||
expected_value=1 |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_router_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
dependency_duplication, |
|||
is_websocket: bool, |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
app = FastAPI() |
|||
router = APIRouter(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
app.include_router(router) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", None)] * 2, |
|||
expected_activation_times=1 if use_cache else dependency_duplication, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|||
def test_dependency_cache_in_same_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
main_dependency_scope: Literal["endpoint", "lifespan"], |
|||
is_websocket: bool, |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def dependency( |
|||
sub_dependency1: Annotated[int, depends], |
|||
sub_dependency2: Annotated[int, depends], |
|||
) -> List[int]: |
|||
return [sub_dependency1, sub_dependency2] |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[List[int], Depends( |
|||
dependency, |
|||
use_cache=use_cache, |
|||
dependency_scope=main_dependency_scope, |
|||
)] |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 1]), |
|||
("/test", [1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 2]), |
|||
("/test", [1, 2]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=2, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_same_endpoint( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)] |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 1, 1]), |
|||
("/test", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 2, 3]), |
|||
("/test", [1, 2, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=3, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_different_endpoints( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test1", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)] |
|||
) |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test2", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)] |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=5, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_no_cached_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
is_websocket, |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=False |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1 |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("annotation", [ |
|||
Annotated[str, Path()], |
|||
Annotated[str, Body()], |
|||
Annotated[str, Query()], |
|||
Annotated[str, Header()], |
|||
SecurityScopes, |
|||
Annotated[str, Cookie()], |
|||
Annotated[str, Form()], |
|||
Annotated[str, File()], |
|||
BackgroundTasks, |
|||
]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|||
annotation, |
|||
is_websocket |
|||
): |
|||
async def dependency_func(param: annotation) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, |
|||
Depends(dependency_func, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
async def lifespan_scoped_dependency( |
|||
param: Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan" |
|||
)] |
|||
) -> AsyncGenerator[int, None]: |
|||
yield param |
|||
|
|||
app = FastAPI() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, Depends(lifespan_scoped_dependency)], |
|||
expected_value=1 |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
is_websocket=is_websocket |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize([ |
|||
"dependency_style", |
|||
"supports_teardown" |
|||
], [ |
|||
(DependencyStyle.SYNC_FUNCTION, False), |
|||
(DependencyStyle.ASYNC_FUNCTION, False), |
|||
(DependencyStyle.SYNC_GENERATOR, True), |
|||
(DependencyStyle.ASYNC_GENERATOR, True), |
|||
]) |
|||
def test_the_same_dependency_can_work_in_different_scopes( |
|||
dependency_style: DependencyStyle, |
|||
supports_teardown, |
|||
is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
app = FastAPI() |
|||
|
|||
create_endpoint_2_annotations( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="endpoint" |
|||
)], |
|||
annotation2=Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan" |
|||
)], |
|||
) |
|||
if is_websocket: |
|||
get_response = use_websocket |
|||
else: |
|||
get_response = use_endpoint |
|||
|
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == 1 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert get_response(client, "/test") == [2, 1] |
|||
assert dependency_factory.activation_times == 2 |
|||
if supports_teardown: |
|||
assert dependency_factory.deactivation_times == 1 |
|||
else: |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert get_response(client, "/test") == [3, 1] |
|||
assert dependency_factory.activation_times == 3 |
|||
if supports_teardown: |
|||
assert dependency_factory.deactivation_times == 2 |
|||
else: |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == 3 |
|||
if supports_teardown: |
|||
assert dependency_factory.deactivation_times == 3 |
|||
else: |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
|
|||
@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]) |
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( |
|||
dependency_style: DependencyStyle, |
|||
is_websocket, |
|||
lifespan_style: Literal["lifespan_function", "lifespan_events"] |
|||
): |
|||
lifespan_started = False |
|||
lifespan_ended = False |
|||
if lifespan_style == "lifespan_generator": |
|||
@asynccontextmanager |
|||
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: |
|||
nonlocal lifespan_started |
|||
nonlocal lifespan_ended |
|||
lifespan_started = True |
|||
yield |
|||
lifespan_ended = True |
|||
|
|||
app = FastAPI(lifespan=lifespan) |
|||
elif lifespan_style == "events_decorator": |
|||
app = FastAPI() |
|||
with warnings.catch_warnings(action="ignore", category=DeprecationWarning): |
|||
@app.on_event("startup") |
|||
async def startup() -> None: |
|||
nonlocal lifespan_started |
|||
lifespan_started = True |
|||
|
|||
@app.on_event("shutdown") |
|||
async def shutdown() -> None: |
|||
nonlocal lifespan_ended |
|||
lifespan_ended = True |
|||
elif lifespan_style == "events_constructor": |
|||
async def startup() -> None: |
|||
nonlocal lifespan_started |
|||
lifespan_started = True |
|||
|
|||
async def shutdown() -> None: |
|||
nonlocal lifespan_ended |
|||
lifespan_ended = True |
|||
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) |
|||
else: |
|||
assert_never(lifespan_style) |
|||
|
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan" |
|||
)], |
|||
expected_value=1 |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
is_websocket=is_websocket |
|||
) |
|||
assert lifespan_started and lifespan_ended |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|||
depends_class, |
|||
is_websocket |
|||
): |
|||
async def sub_dependency() -> None: |
|||
pass |
|||
|
|||
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")], |
|||
) |
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_dependencies_must_provide_correct_dependency_scope( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
with pytest.raises( |
|||
InvalidDependencyScope, |
|||
match=r'Dependency "value" of .* has an invalid scope: ' |
|||
r'"incorrect"' |
|||
): |
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[None, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="incorrect", |
|||
use_cache=use_cache, |
|||
)] |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_incorrect_dependency_scope( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
# We intentionally change the dependency scope here to bypass the |
|||
# validation at the function level. |
|||
depends.dependency_scope = "asdad" |
|||
|
|||
with pytest.raises(InvalidDependencyScope): |
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends] |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1 |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
|||
|
|||
try: |
|||
with pytest.raises(UninitializedLifespanDependency): |
|||
if is_websocket: |
|||
with client.websocket_connect("/test"): |
|||
pass |
|||
else: |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_internal_lifespan( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1 |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
internal_state = client.app_state["__fastapi__"] |
|||
del client.app_state["__fastapi__"] |
|||
|
|||
try: |
|||
with pytest.raises(UninitializedLifespanDependency): |
|||
if is_websocket: |
|||
with client.websocket_connect("/test"): |
|||
pass |
|||
else: |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"] = internal_state |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_bad_lifespan_scoped_dependencies( |
|||
use_cache, |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
is_websocket |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style, should_error=True) |
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1 |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with pytest.raises(IntentionallyBadDependency) as exception_info: |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
assert exception_info.value.args == (1,) |
@ -0,0 +1,202 @@ |
|||
from enum import StrEnum, auto |
|||
from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never |
|||
|
|||
from fastapi import APIRouter, FastAPI, WebSocket |
|||
from starlette.testclient import TestClient |
|||
from starlette.websockets import WebSocketDisconnect |
|||
|
|||
T = TypeVar('T') |
|||
|
|||
|
|||
class DependencyStyle(StrEnum): |
|||
SYNC_FUNCTION = auto() |
|||
ASYNC_FUNCTION = auto() |
|||
SYNC_GENERATOR = auto() |
|||
ASYNC_GENERATOR = auto() |
|||
|
|||
|
|||
class IntentionallyBadDependency(Exception): |
|||
pass |
|||
|
|||
|
|||
class DependencyFactory: |
|||
def __init__( |
|||
self, |
|||
dependency_style: DependencyStyle, *, |
|||
should_error: bool = False, |
|||
value_offset: int = 0, |
|||
): |
|||
self.activation_times = 0 |
|||
self.deactivation_times = 0 |
|||
self.dependency_style = dependency_style |
|||
self._should_error = should_error |
|||
self._value_offset = value_offset |
|||
|
|||
def get_dependency(self): |
|||
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
|||
return self._synchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
|||
return self._synchronous_generator_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
|||
return self._asynchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
|||
return self._asynchronous_generator_dependency |
|||
|
|||
assert_never(self.dependency_style) |
|||
|
|||
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
yield self.activation_times + self._value_offset |
|||
self.deactivation_times += 1 |
|||
|
|||
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
yield self.activation_times + self._value_offset |
|||
self.deactivation_times += 1 |
|||
|
|||
async def _asynchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
return self.activation_times + self._value_offset |
|||
|
|||
def _synchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
return self.activation_times + self._value_offset |
|||
|
|||
|
|||
def use_endpoint(client: TestClient, url: str) -> Any: |
|||
response = client.post(url) |
|||
response.raise_for_status() |
|||
return response.json() |
|||
|
|||
|
|||
def use_websocket(client: TestClient, url: str) -> Any: |
|||
with client.websocket_connect(url) as connection: |
|||
return connection.receive_json() |
|||
|
|||
|
|||
def create_endpoint_0_annotations( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
) -> None: |
|||
if is_websocket: |
|||
@router.websocket(path) |
|||
async def endpoint(websocket: WebSocket) -> None: |
|||
await websocket.accept() |
|||
try: |
|||
await websocket.send_json(None) |
|||
except WebSocketDisconnect: |
|||
pass |
|||
else: |
|||
@router.post(path) |
|||
async def endpoint() -> None: |
|||
return None |
|||
|
|||
|
|||
def create_endpoint_1_annotation( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
annotation: Any, |
|||
expected_value: Any = None |
|||
) -> None: |
|||
if is_websocket: |
|||
@router.websocket(path) |
|||
async def endpoint( |
|||
websocket: WebSocket, |
|||
value: annotation |
|||
) -> None: |
|||
if expected_value is not None: |
|||
assert value == expected_value |
|||
|
|||
await websocket.accept() |
|||
try: |
|||
await websocket.send_json(value) |
|||
except WebSocketDisconnect: |
|||
pass |
|||
else: |
|||
@router.post(path) |
|||
async def endpoint( |
|||
value: annotation |
|||
) -> None: |
|||
if expected_value is not None: |
|||
assert value == expected_value |
|||
|
|||
return value |
|||
|
|||
def create_endpoint_2_annotations( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
annotation1: Any, |
|||
annotation2: Any, |
|||
) -> None: |
|||
if is_websocket: |
|||
@router.websocket(path) |
|||
async def endpoint( |
|||
websocket: WebSocket, |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
) -> None: |
|||
await websocket.accept() |
|||
try: |
|||
await websocket.send_json([value1, value2]) |
|||
except WebSocketDisconnect: |
|||
await websocket.close() |
|||
else: |
|||
@router.post(path) |
|||
async def endpoint( |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
) -> list[Any]: |
|||
return [value1, value2] |
|||
|
|||
|
|||
def create_endpoint_3_annotations( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
annotation1: Any, |
|||
annotation2: Any, |
|||
annotation3: Any |
|||
) -> None: |
|||
if is_websocket: |
|||
@router.websocket(path) |
|||
async def endpoint( |
|||
websocket: WebSocket, |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
value3: annotation3 |
|||
) -> None: |
|||
await websocket.accept() |
|||
try: |
|||
await websocket.send_json([value1, value2, value3]) |
|||
except WebSocketDisconnect: |
|||
await websocket.close() |
|||
else: |
|||
@router.post(path) |
|||
async def endpoint( |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
value3: annotation3 |
|||
) -> list[Any]: |
|||
return [value1, value2, value3] |
Loading…
Reference in new issue