11 changed files with 1208 additions and 95 deletions
@ -0,0 +1,46 @@ |
|||
from __future__ import annotations |
|||
|
|||
from contextlib import AsyncExitStack |
|||
from typing import TYPE_CHECKING, Any, Callable, Dict, List |
|||
|
|||
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey |
|||
from fastapi.dependencies.utils import solve_lifespan_dependant |
|||
from fastapi.routing import APIRoute |
|||
|
|||
if TYPE_CHECKING: |
|||
from fastapi import FastAPI |
|||
|
|||
|
|||
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: |
|||
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} |
|||
for route in app.router.routes: |
|||
if not isinstance(route, APIRoute): |
|||
continue |
|||
|
|||
for sub_dependant in route.lifespan_dependencies: |
|||
if sub_dependant.cache_key in lifespan_dependants_cache: |
|||
continue |
|||
|
|||
lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant |
|||
|
|||
return list(lifespan_dependants_cache.values()) |
|||
|
|||
|
|||
async def resolve_lifespan_dependants( |
|||
*, |
|||
app: FastAPI, |
|||
async_exit_stack: AsyncExitStack |
|||
) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: |
|||
lifespan_dependants = _get_lifespan_dependants(app) |
|||
dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} |
|||
for lifespan_dependant in lifespan_dependants: |
|||
solved_dependency = await solve_lifespan_dependant( |
|||
dependant=lifespan_dependant, |
|||
dependency_overrides_provider=app, |
|||
dependency_cache=dependency_cache, |
|||
async_exit_stack=async_exit_stack |
|||
) |
|||
|
|||
dependency_cache.update(solved_dependency.dependency_cache) |
|||
|
|||
return dependency_cache |
@ -0,0 +1,703 @@ |
|||
from enum import StrEnum, auto |
|||
from typing import Any, AsyncGenerator, List, Tuple, TypeVar |
|||
|
|||
import pytest |
|||
from fastapi import ( |
|||
APIRouter, |
|||
BackgroundTasks, |
|||
Body, |
|||
Cookie, |
|||
Depends, |
|||
FastAPI, |
|||
File, |
|||
Form, |
|||
Header, |
|||
Path, |
|||
Query, |
|||
) |
|||
from fastapi.exceptions import FastAPIError |
|||
from fastapi.params import Security |
|||
from fastapi.security import SecurityScopes |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Annotated, Generator, Literal, assert_never |
|||
|
|||
T = TypeVar('T') |
|||
|
|||
|
|||
class DependencyStyle(StrEnum): |
|||
SYNC_FUNCTION = auto() |
|||
ASYNC_FUNCTION = auto() |
|||
SYNC_GENERATOR = auto() |
|||
ASYNC_GENERATOR = auto() |
|||
|
|||
|
|||
class DependencyFactory: |
|||
def __init__( |
|||
self, |
|||
dependency_style: DependencyStyle, *, |
|||
should_error: bool = False |
|||
): |
|||
self.activation_times = 0 |
|||
self.deactivation_times = 0 |
|||
self.dependency_style = dependency_style |
|||
self._should_error = should_error |
|||
|
|||
def get_dependency(self): |
|||
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
|||
return self._synchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
|||
return self._synchronous_generator_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
|||
return self._asynchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
|||
return self._asynchronous_generator_dependency |
|||
|
|||
assert_never(self.dependency_style) |
|||
|
|||
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
yield self.activation_times |
|||
self.deactivation_times += 1 |
|||
|
|||
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
yield self.activation_times |
|||
self.deactivation_times += 1 |
|||
|
|||
async def _asynchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
return self.activation_times |
|||
|
|||
def _synchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise ValueError(self.activation_times) |
|||
|
|||
return self.activation_times |
|||
|
|||
|
|||
def _expect_correct_amount_of_dependency_activations( |
|||
*, |
|||
app: FastAPI, |
|||
dependency_factory: DependencyFactory, |
|||
urls_and_responses: List[Tuple[str, Any]], |
|||
expected_activation_times: int |
|||
) -> None: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
for url, expected_response in urls_and_responses: |
|||
response = client.post(url) |
|||
response.raise_for_status() |
|||
assert response.json() == expected_response |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
if dependency_factory.dependency_style not in ( |
|||
DependencyStyle.SYNC_FUNCTION, |
|||
DependencyStyle.ASYNC_FUNCTION |
|||
): |
|||
assert dependency_factory.deactivation_times == expected_activation_times |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[None, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
)] |
|||
) -> None: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1 |
|||
) |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_router_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends]) |
|||
|
|||
@app.post("/test") |
|||
async def endpoint() -> None: |
|||
return None |
|||
else: |
|||
app = FastAPI() |
|||
router = APIRouter(dependencies=[depends]) |
|||
|
|||
@router.post("/test") |
|||
async def endpoint() -> None: |
|||
return None |
|||
|
|||
app.include_router(router) |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", None)] * 2, |
|||
expected_activation_times=1 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|||
def test_dependency_cache_in_same_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
main_dependency_scope: Literal["endpoint", "lifespan"] |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def dependency( |
|||
sub_dependency1: Annotated[int, depends], |
|||
sub_dependency2: Annotated[int, depends], |
|||
) -> List[int]: |
|||
return [sub_dependency1, sub_dependency2] |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[List[int], Depends( |
|||
dependency, |
|||
use_cache=use_cache, |
|||
dependency_scope=main_dependency_scope, |
|||
)] |
|||
) -> List[int]: |
|||
return dependency |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 1]), |
|||
("/test", [1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1 |
|||
) |
|||
else: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 2]), |
|||
("/test", [1, 2]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=2 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_same_endpoint( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
@router.post("/test1") |
|||
async def endpoint( |
|||
dependency1: Annotated[int, depends], |
|||
dependency2: Annotated[int, depends], |
|||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|||
) -> List[int]: |
|||
return [dependency1, dependency2, dependency3] |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 1, 1]), |
|||
("/test1", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1 |
|||
) |
|||
else: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 2, 3]), |
|||
("/test1", [1, 2, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=3 |
|||
) |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_different_endpoints( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
@router.post("/test1") |
|||
async def endpoint( |
|||
dependency1: Annotated[int, depends], |
|||
dependency2: Annotated[int, depends], |
|||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|||
) -> List[int]: |
|||
return [dependency1, dependency2, dependency3] |
|||
|
|||
@router.post("/test2") |
|||
async def endpoint2( |
|||
dependency1: Annotated[int, depends], |
|||
dependency2: Annotated[int, depends], |
|||
dependency3: Annotated[int, Depends(endpoint_dependency)] |
|||
) -> List[int]: |
|||
return [dependency1, dependency2, dependency3] |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1 |
|||
) |
|||
else: |
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=5 |
|||
) |
|||
|
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_no_cached_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=False |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends], |
|||
) -> int: |
|||
return dependency |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("annotation", [ |
|||
Annotated[str, Path()], |
|||
Annotated[str, Body()], |
|||
Annotated[str, Query()], |
|||
Annotated[str, Header()], |
|||
SecurityScopes, |
|||
Annotated[str, Cookie()], |
|||
Annotated[str, Form()], |
|||
Annotated[str, File()], |
|||
BackgroundTasks, |
|||
]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|||
annotation |
|||
): |
|||
async def dependency_func(param: annotation) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@app.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[ |
|||
None, Depends(dependency_func, dependency_scope="lifespan")] |
|||
) -> None: |
|||
return |
|||
|
|||
|
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
|||
dependency_style: DependencyStyle |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
async def lifespan_scoped_dependency( |
|||
param: Annotated[int, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan" |
|||
)] |
|||
) -> AsyncGenerator[int, None]: |
|||
yield param |
|||
|
|||
app = FastAPI() |
|||
|
|||
@app.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, Depends( |
|||
lifespan_scoped_dependency, |
|||
dependency_scope="lifespan" |
|||
)] |
|||
) -> int: |
|||
return dependency |
|||
|
|||
_expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 1)] * 2 |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|||
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[ |
|||
"websocket", "endpoint" |
|||
]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|||
depends_class, |
|||
route_type |
|||
): |
|||
async def sub_dependency() -> None: |
|||
pass |
|||
|
|||
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None: |
|||
yield |
|||
|
|||
app = FastAPI() |
|||
route_decorator = route_type(app, "/test") |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@route_decorator |
|||
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")] |
|||
) -> None: |
|||
return |
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_dependencies_must_provide_correct_dependency_scope( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[None, Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="incorrect", |
|||
use_cache=use_cache, |
|||
)] |
|||
) -> None: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_incorrect_dependency_scope( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
# We intentionally change the dependency scope here to bypass the |
|||
# validation at the function level. |
|||
depends.dependency_scope = "asdad" |
|||
|
|||
with pytest.raises(FastAPIError): |
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
|||
|
|||
try: |
|||
with pytest.raises(FastAPIError): |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_internal_lifespan( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache |
|||
): |
|||
dependency_factory= DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
internal_state = client.app_state["__fastapi__"] |
|||
del client.app_state["__fastapi__"] |
|||
|
|||
try: |
|||
with pytest.raises(FastAPIError): |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"] = internal_state |
|||
|
|||
|
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style): |
|||
dependency_factory= DependencyFactory(dependency_style, should_error=True) |
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
@router.post("/test") |
|||
async def endpoint( |
|||
dependency: Annotated[int, depends] |
|||
) -> int: |
|||
assert dependency == 1 |
|||
return dependency |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with pytest.raises(ValueError) as exception_info: |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
assert exception_info.value.args == (1,) |
|||
|
|||
|
|||
# TODO: Add tests for dependency_overrides |
|||
# TODO: Add a websocket equivalent to all tests |
Loading…
Reference in new issue