Browse Source

Added tests for dependency overrides and websockets. Fixed bugs related to the deprecated startup and shutdown events. Fixed bugs related to dependency duplcatation within the same router scope. Made more specific dependency related exceptions. Fixed some linting and mypy related issues.

pull/12529/head
Nir Schulman 9 months ago
parent
commit
c4860bfb7c
  1. 11
      fastapi/applications.py
  2. 26
      fastapi/dependencies/models.py
  3. 164
      fastapi/dependencies/utils.py
  4. 16
      fastapi/exceptions.py
  5. 4
      fastapi/lifespan.py
  6. 14
      fastapi/routing.py
  7. 703
      tests/test_lifespan_scoped_dependencies.py
  8. 0
      tests/test_lifespan_scoped_dependencies/__init__.py
  9. 634
      tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py
  10. 854
      tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py
  11. 202
      tests/test_lifespan_scoped_dependencies/testing_utilities.py

11
fastapi/applications.py

@ -946,8 +946,13 @@ class FastAPI(Starlette):
# Since we always use a lifespan, starlette will no longer run event
# handlers which are defined in the scope of the application.
# We therefore need to call them ourselves.
self._on_startup = on_startup or []
self._on_shutdown = on_shutdown or []
if on_startup is None:
on_startup = []
if on_shutdown is None:
on_shutdown = []
self._on_startup = list(on_startup)
self._on_shutdown = list(on_shutdown)
self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
@ -982,7 +987,7 @@ class FastAPI(Starlette):
self.setup()
@asynccontextmanager
async def _internal_lifespan(self) -> AsyncGenerator[dict[str, Any], None]:
async def _internal_lifespan(self) -> AsyncGenerator[Dict[str, Any], None]:
async with AsyncExitStack() as exit_stack:
lifespan_scoped_dependencies = await resolve_lifespan_dependants(
app=self,

26
fastapi/dependencies/models.py

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
@ -12,22 +12,28 @@ class SecurityRequirement:
scopes: Optional[Sequence[str]] = None
LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], str], Callable[..., Any]]
LifespanDependantCacheKey: TypeAlias = Union[Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]]
@dataclass
class LifespanDependant:
call: Callable[..., Any]
caller: Callable[..., Any]
dependencies: List["LifespanDependant"] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
index: Optional[int] = None
cache_key: LifespanDependantCacheKey = field(init=False)
def __post_init__(self) -> None:
if self.use_cache:
self.cache_key = self.call
else:
elif self.name is not None:
self.cache_key = (self.caller, self.name)
else:
assert self.index is not None, (
"Lifespan dependency must have an associated name or index."
)
self.cache_key = (self.caller, self.index)
EndpointDependantCacheKey: TypeAlias = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
@ -39,6 +45,7 @@ class EndpointDependant:
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
index: Optional[int] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(
init=False)
path_params: List[ModelField] = field(default_factory=list)
@ -62,7 +69,16 @@ class EndpointDependant:
# Kept for backwards compatibility
@property
def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]:
return tuple(self.endpoint_dependencies + self.lifespan_dependencies)
lifespan_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]],
self.lifespan_dependencies
)
endpoint_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]],
self.endpoint_dependencies
)
return tuple(lifespan_dependencies + endpoint_dependencies)
# Kept for backwards compatibility
Dependant = EndpointDependant

164
fastapi/dependencies/utils.py

@ -54,11 +54,16 @@ from fastapi.concurrency import (
from fastapi.dependencies.models import (
CacheKey,
EndpointDependant,
EndpointDependantCacheKey,
LifespanDependant,
LifespanDependantCacheKey,
SecurityRequirement,
)
from fastapi.exceptions import FastAPIError
from fastapi.exceptions import (
DependencyScopeConflict,
InvalidDependencyScope,
UninitializedLifespanDependency,
)
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -78,7 +83,7 @@ from starlette.datastructures import (
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, assert_never, get_args, get_origin
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
@ -137,7 +142,8 @@ def get_parameterless_sub_dependant(
*,
depends: params.Depends,
path: str,
caller: Callable[..., Any]
caller: Callable[..., Any],
index: int
) -> Union[EndpointDependant, LifespanDependant]:
assert callable(
depends.dependency
@ -146,7 +152,8 @@ def get_parameterless_sub_dependant(
depends=depends,
dependency=depends.dependency,
path=path,
caller=caller
caller=caller,
index=index
)
@ -158,13 +165,15 @@ def get_sub_dependant(
caller: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
index: Optional[int] = None,
) -> Union[EndpointDependant, LifespanDependant]:
if depends.dependency_scope == "lifespan":
return get_lifespan_dependant(
caller=caller,
call=depends.dependency,
call=dependency,
name=name,
use_cache=depends.use_cache
use_cache=depends.use_cache,
index=index
)
elif depends.dependency_scope == "endpoint":
security_requirement = None
@ -185,14 +194,15 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
index=index
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
else:
raise ValueError(
f"Dependency {name} of {caller} has an invalid "
f"sub-dependency scope: {depends.dependency_scope}"
raise InvalidDependencyScope(
f"Dependency \"{name}\" of {caller} has an invalid "
f"scope: \"{depends.dependency_scope}\""
)
@ -292,6 +302,7 @@ def get_lifespan_dependant(
call: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True,
index: Optional[int] = None
) -> LifespanDependant:
dependency_signature = get_typed_signature(call)
signature_params = dependency_signature.parameters
@ -299,7 +310,8 @@ def get_lifespan_dependant(
call=call,
name=name,
use_cache=use_cache,
caller=caller
caller=caller,
index=index
)
for param_name, param in signature_params.items():
param_details = analyze_param(
@ -309,17 +321,23 @@ def get_lifespan_dependant(
is_path_param=False,
)
if param_details.depends is None:
raise FastAPIError(
f"Lifespan dependency {dependant.name} was defined with an "
f"invalid argument {param_name}. Lifespan dependencies may "
f"only use other lifespan dependencies as arguments.")
raise DependencyScopeConflict(
f"Lifespan scoped dependency \"{dependant.name}\" was defined "
f"with an invalid argument: \"{param_name}\" which is "
f"\"endpoint\" scoped. Lifespan scoped dependencies may only "
f"use lifespan scoped sub-dependencies.")
if param_details.depends.dependency_scope != "lifespan":
raise FastAPIError(
"Lifespan dependency may not use "
"sub-dependencies of other scopes."
raise DependencyScopeConflict(
f"Lifespan scoped dependency {dependant.name} was defined with the "
f"sub-dependency \"{param_name}\" which is "
f"\"{param_details.depends.dependency_scope}\" scoped. "
f"Lifespan scoped dependencies may only use lifespan scoped "
f"sub-dependencies."
)
assert param_details.depends.dependency is not None
sub_dependant = get_lifespan_dependant(
name=param_name,
call=param_details.depends.dependency,
@ -339,6 +357,7 @@ def get_endpoint_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
index: Optional[int] = None
) -> EndpointDependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
@ -349,6 +368,7 @@ def get_endpoint_dependant(
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
index=index
)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
@ -359,28 +379,19 @@ def get_endpoint_dependant(
is_path_param=is_path_param,
)
if param_details.depends is not None:
if param_details.depends.dependency_scope == "endpoint":
sub_dependant = get_param_sub_dependant(
param_name=param_name,
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
caller=call,
)
sub_dependant = get_param_sub_dependant(
param_name=param_name,
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
caller=call,
)
if isinstance(sub_dependant, EndpointDependant):
dependant.endpoint_dependencies.append(sub_dependant)
elif param_details.depends.dependency_scope == "lifespan":
sub_dependant = get_lifespan_dependant(
caller=call,
call=param_details.depends.dependency,
name=param_name,
use_cache=param_details.depends.use_cache,
)
elif isinstance(sub_dependant, LifespanDependant):
dependant.lifespan_dependencies.append(sub_dependant)
else:
raise FastAPIError(
f"Dependency \"{param_name}\" of `{call}` has an invalid "
f"sub-dependency scope: \"{param_details.depends.dependency_scope}\""
)
assert_never(sub_dependant)
continue
if add_non_field_param_to_dependency(
param_name=param_name,
@ -652,7 +663,7 @@ async def solve_generator(
@dataclass
class SolvedLifespanDependant:
value: Any
dependency_cache: Dict[Callable[..., Any], Any]
dependency_cache: Dict[LifespanDependantCacheKey, Any]
async def solve_lifespan_dependant(
@ -669,35 +680,33 @@ async def solve_lifespan_dependant(
dependency_cache=dependency_cache,
)
dependency_arguments: Dict[str, Any] = {}
sub_dependant: LifespanDependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Callable[..., Any], sub_dependant.cache_key
call = dependant.call
dependant_to_solve = dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
call = getattr(
dependency_overrides_provider,
"dependency_overrides",
{}
).get(dependant.call, dependant.call)
dependant_to_solve = get_lifespan_dependant(
caller=dependant.caller,
call=call,
name=dependant.name,
use_cache=dependant.use_cache,
index=dependant.index,
)
dependency_arguments: Dict[str, Any] = {}
for sub_dependant in dependant_to_solve.dependencies:
assert sub_dependant.name, (
"Lifespan scoped dependencies should not be able to have "
"subdependencies with no name"
)
sub_dependant_to_solve = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
sub_dependant_to_solve = get_lifespan_dependant(
call=call,
name=sub_dependant.name,
caller=dependant.call
)
solved_sub_dependant = await solve_lifespan_dependant(
dependant=sub_dependant_to_solve,
dependant=sub_dependant,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
@ -705,16 +714,16 @@ async def solve_lifespan_dependant(
dependency_cache.update(solved_sub_dependant.dependency_cache)
dependency_arguments[sub_dependant.name] = solved_sub_dependant.value
if is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call):
if is_gen_callable(call) or is_async_gen_callable(call):
value = await solve_generator(
call=dependant.call,
call=call,
stack=async_exit_stack,
sub_values=dependency_arguments
)
elif is_coroutine_callable(dependant.call):
value = await dependant.call(**dependency_arguments)
elif is_coroutine_callable(call):
value = await call(**dependency_arguments)
else:
value = await run_in_threadpool(dependant.call, **dependency_arguments)
value = await run_in_threadpool(call, **dependency_arguments)
if dependant.cache_key not in dependency_cache:
dependency_cache[dependant.cache_key] = value
@ -731,7 +740,7 @@ class SolvedDependency:
errors: List[Any]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
dependency_cache: Dict[EndpointDependantCacheKey, Any]
async def solve_dependencies(
@ -742,33 +751,34 @@ async def solve_dependencies(
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
dependency_cache: Optional[Dict[EndpointDependantCacheKey, Any]] = None,
async_exit_stack: AsyncExitStack,
embed_body_fields: bool,
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
for sub_dependant in dependant.lifespan_dependencies:
if sub_dependant.name is None:
for lifespan_sub_dependant in dependant.lifespan_dependencies:
if lifespan_sub_dependant.name is None:
continue
try:
lifespan_scoped_dependencies = request.state.__fastapi__[
"lifespan_scoped_dependencies"]
except AttributeError as e:
raise FastAPIError(
"FastAPI's internal lifespan was not initialized"
except (AttributeError, KeyError) as e:
raise UninitializedLifespanDependency(
"FastAPI's internal lifespan was not initialized correctly."
) from e
try:
value = lifespan_scoped_dependencies[sub_dependant.cache_key]
value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key]
except KeyError as e:
raise FastAPIError(
f"Dependency {sub_dependant.name} of {dependant.call} "
f"was not initialized."
raise UninitializedLifespanDependency(
f"Dependency \"{lifespan_sub_dependant.name}\" of "
f"`{dependant.call}` was not initialized correctly."
) from e
values[sub_dependant.name] = value
values[lifespan_sub_dependant.name] = value
if response is None:
response = Response()

16
fastapi/exceptions.py

@ -146,6 +146,22 @@ class FastAPIError(RuntimeError):
"""
class DependencyError(FastAPIError):
pass
class InvalidDependencyScope(DependencyError):
pass
class DependencyScopeConflict(DependencyError):
pass
class UninitializedLifespanDependency(DependencyError):
pass
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
self._errors = errors

4
fastapi/lifespan.py

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey
from fastapi.dependencies.utils import solve_lifespan_dependant
from fastapi.routing import APIRoute
from fastapi.routing import APIRoute, APIWebSocketRoute
if TYPE_CHECKING:
from fastapi import FastAPI
@ -14,7 +14,7 @@ if TYPE_CHECKING:
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]:
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {}
for route in app.router.routes:
if not isinstance(route, APIRoute):
if not isinstance(route, (APIWebSocketRoute, APIRoute)):
continue
for sub_dependant in route.lifespan_dependencies:

14
fastapi/routing.py

@ -401,15 +401,18 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
caller=self
caller=self.__call__,
index=i
)
if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant)
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)
@ -564,15 +567,18 @@ class APIRoute(routing.Route):
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_endpoint_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant(
depends=depends,
path=self.path_format,
caller=self.__call__
caller=self.__call__,
index=i
)
if depends.dependency_scope == "endpoint":
assert isinstance(sub_dependant, EndpointDependant)
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
elif depends.dependency_scope == "lifespan":
assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
else:
assert_never(depends.dependency_scope)

703
tests/test_lifespan_scoped_dependencies.py

@ -1,703 +0,0 @@
from enum import StrEnum, auto
from typing import Any, AsyncGenerator, List, Tuple, TypeVar
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
)
from fastapi.exceptions import FastAPIError
from fastapi.params import Security
from fastapi.security import SecurityScopes
from starlette.testclient import TestClient
from typing_extensions import Annotated, Generator, Literal, assert_never
T = TypeVar('T')
class DependencyStyle(StrEnum):
SYNC_FUNCTION = auto()
ASYNC_FUNCTION = auto()
SYNC_GENERATOR = auto()
ASYNC_GENERATOR = auto()
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle, *,
should_error: bool = False
):
self.activation_times = 0
self.deactivation_times = 0
self.dependency_style = dependency_style
self._should_error = should_error
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
return self._synchronous_function_dependency
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
return self._synchronous_generator_dependency
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
return self._asynchronous_function_dependency
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style)
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
yield self.activation_times
self.deactivation_times += 1
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _synchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise ValueError(self.activation_times)
return self.activation_times
def _expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
response = client.post(url)
response.raise_for_status()
assert response.json() == expected_response
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
):
assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(dependency_style: DependencyStyle, routing_style, use_cache):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)]
) -> None:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
if routing_style == "app":
app = FastAPI(dependencies=[depends])
@app.post("/test")
async def endpoint() -> None:
return None
else:
app = FastAPI()
router = APIRouter(dependencies=[depends])
@router.post("/test")
async def endpoint() -> None:
return None
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"]
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
@router.post("/test")
async def endpoint(
dependency: Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
) -> List[int]:
return dependency
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1]),
("/test", [1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2]),
("/test", [1, 2]),
],
dependency_factory=dependency_factory,
expected_activation_times=2
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
@router.post("/test1")
async def endpoint(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test1", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test1", [1, 2, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=3
)
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
@router.post("/test1")
async def endpoint(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
@router.post("/test2")
async def endpoint2(
dependency1: Annotated[int, depends],
dependency2: Annotated[int, depends],
dependency3: Annotated[int, Depends(endpoint_dependency)]
) -> List[int]:
return [dependency1, dependency2, dependency3]
if routing_style == "router":
app.include_router(router)
if use_cache:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1
)
else:
_expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=5
)
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends],
) -> int:
return dependency
if routing_style == "router":
app.include_router(router)
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1
)
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation
):
async def dependency_func(param: annotation) -> None:
yield
app = FastAPI()
with pytest.raises(FastAPIError):
@app.post("/test")
async def endpoint(
dependency: Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")]
) -> None:
return
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle
):
dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
@app.post("/test")
async def endpoint(
dependency: Annotated[int, Depends(
lifespan_scoped_dependency,
dependency_scope="lifespan"
)]
) -> int:
return dependency
_expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2
)
@pytest.mark.parametrize("depends_class", [Depends, Security])
@pytest.mark.parametrize("route_type", [FastAPI.post, FastAPI.websocket], ids=[
"websocket", "endpoint"
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
route_type
):
async def sub_dependency() -> None:
pass
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
yield
app = FastAPI()
route_decorator = route_type(app, "/test")
with pytest.raises(FastAPIError):
@route_decorator
async def endpoint(x: Annotated[None, Depends(dependency_func, dependency_scope="lifespan")]
) -> None:
return
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
with pytest.raises(FastAPIError):
@router.post("/test")
async def endpoint(
dependency: Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
)]
) -> None:
assert dependency == 1
return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
with pytest.raises(FastAPIError):
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"]
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {}
try:
with pytest.raises(FastAPIError):
client.post("/test")
finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle,
routing_style,
use_cache
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
internal_state = client.app_state["__fastapi__"]
del client.app_state["__fastapi__"]
try:
with pytest.raises(FastAPIError):
client.post("/test")
finally:
client.app_state["__fastapi__"] = internal_state
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(use_cache, dependency_style: DependencyStyle, routing_style):
dependency_factory= DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
@router.post("/test")
async def endpoint(
dependency: Annotated[int, depends]
) -> int:
assert dependency == 1
return dependency
if routing_style == "router_endpoint":
app.include_router(router)
with pytest.raises(ValueError) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)
# TODO: Add tests for dependency_overrides
# TODO: Add a websocket equivalent to all tests

0
tests/test_lifespan_scoped_dependencies/__init__.py

634
tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py

@ -0,0 +1,634 @@
from typing import Any, AsyncGenerator, List, Tuple
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
)
from fastapi.exceptions import DependencyScopeConflict
from fastapi.params import Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated, Literal
from tests.test_lifespan_scoped_dependencies.testing_utilities import (
DependencyFactory,
DependencyStyle,
IntentionallyBadDependency,
create_endpoint_0_annotations,
create_endpoint_1_annotation,
create_endpoint_3_annotations,
use_endpoint,
use_websocket,
)
def expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
override_dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == 0
assert override_dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
assert override_dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
if is_websocket:
response = use_websocket(client, url)
else:
response = use_endpoint(client, url)
assert response == expected_response
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
assert override_dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
):
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
],
expected_value=11
)
if routing_style == "router_endpoint":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=app,
path="/test",
is_websocket=is_websocket
)
else:
app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket
)
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()
] = override_dependency_factory.get_dependency()
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [11, 11]),
("/test", [11, 11]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [11, 12]),
("/test", [11, 12]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=2,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test1",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()
] = override_dependency_factory.get_dependency()
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 11, 11]),
("/test1", [11, 11, 11]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 12, 13]),
("/test1", [11, 12, 13]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=3,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test1",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
create_endpoint_3_annotations(
router=router,
path="/test2",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 11, 11]),
("/test2", [11, 11, 11]),
("/test1", [11, 11, 11]),
("/test2", [11, 11, 11]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 12, 13]),
("/test2", [14, 15, 13]),
("/test1", [11, 12, 13]),
("/test2", [14, 15, 13]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=5,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[
dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation,
is_websocket
):
async def dependency_func() -> None:
yield
async def override_dependency_func(param: annotation) -> None:
yield
app = FastAPI()
app.dependency_overrides[dependency_func] = override_dependency_func
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None,
Depends(dependency_func, dependency_scope="lifespan")
]
)
with pytest.raises(DependencyScopeConflict):
with TestClient(app):
pass
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies(
dependency_style: DependencyStyle,
is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(
dependency_style,
value_offset=10
)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
int,
Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
],
)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 11)] * 2,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
is_websocket
):
async def sub_dependency() -> None:
pass
async def dependency_func() -> None:
yield
async def override_dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
yield
app = FastAPI()
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")]
)
app.dependency_overrides[dependency_func] = override_dependency_func
with pytest.raises(DependencyScopeConflict):
with TestClient(app):
pass
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_override_lifespan_scoped_dependencies(
use_cache,
dependency_style: DependencyStyle,
routing_style,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends]
)
if routing_style == "router_endpoint":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = override_dependency_factory.get_dependency()
with pytest.raises(IntentionallyBadDependency) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)

854
tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py

@ -0,0 +1,854 @@
import warnings
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Tuple
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
)
from fastapi.exceptions import (
DependencyScopeConflict,
InvalidDependencyScope,
UninitializedLifespanDependency,
)
from fastapi.params import Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated, Literal, assert_never
from tests.test_lifespan_scoped_dependencies.testing_utilities import (
DependencyFactory,
DependencyStyle,
IntentionallyBadDependency,
create_endpoint_0_annotations,
create_endpoint_1_annotation,
create_endpoint_2_annotations,
create_endpoint_3_annotations,
use_endpoint,
use_websocket,
)
def expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
if is_websocket:
assert use_websocket(client, url) == expected_response
else:
assert use_endpoint(client, url) == expected_response
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION
):
assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False], ids=["With Cache", "Without Cache"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket: bool,
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)],
expected_value=1
)
if routing_style == "router_endpoint":
app.include_router(router)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket: bool,
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=app,
path="/test",
is_websocket=is_websocket
)
else:
app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket
)
app.include_router(router)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket: bool,
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[List[int], Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
)]
)
if routing_style == "router":
app.include_router(router)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1]),
("/test", [1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2]),
("/test", [1, 2]),
],
dependency_factory=dependency_factory,
expected_activation_times=2,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)]
)
if routing_style == "router":
app.include_router(router)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1, 1]),
("/test", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2, 3]),
("/test", [1, 2, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=3,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test1",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)]
)
create_endpoint_3_annotations(
router=router,
path="/test2",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)]
)
if routing_style == "router":
app.include_router(router)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=5,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
is_websocket,
):
dependency_factory= DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
)
if routing_style == "router":
app.include_router(router)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("annotation", [
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation,
is_websocket
):
async def dependency_func(param: annotation) -> None:
yield
app = FastAPI()
with pytest.raises(DependencyScopeConflict):
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None,
Depends(dependency_func, dependency_scope="lifespan")
],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle,
is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency(
param: Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)]
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, Depends(lifespan_scoped_dependency)],
expected_value=1
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize([
"dependency_style",
"supports_teardown"
], [
(DependencyStyle.SYNC_FUNCTION, False),
(DependencyStyle.ASYNC_FUNCTION, False),
(DependencyStyle.SYNC_GENERATOR, True),
(DependencyStyle.ASYNC_GENERATOR, True),
])
def test_the_same_dependency_can_work_in_different_scopes(
dependency_style: DependencyStyle,
supports_teardown,
is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
create_endpoint_2_annotations(
router=app,
path="/test",
is_websocket=is_websocket,
annotation1=Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="endpoint"
)],
annotation2=Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)],
)
if is_websocket:
get_response = use_websocket
else:
get_response = use_endpoint
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == 1
assert dependency_factory.deactivation_times == 0
assert get_response(client, "/test") == [2, 1]
assert dependency_factory.activation_times == 2
if supports_teardown:
assert dependency_factory.deactivation_times == 1
else:
assert dependency_factory.deactivation_times == 0
assert get_response(client, "/test") == [3, 1]
assert dependency_factory.activation_times == 3
if supports_teardown:
assert dependency_factory.deactivation_times == 2
else:
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == 3
if supports_teardown:
assert dependency_factory.deactivation_times == 3
else:
assert dependency_factory.deactivation_times == 0
@pytest.mark.parametrize("lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"])
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
dependency_style: DependencyStyle,
is_websocket,
lifespan_style: Literal["lifespan_function", "lifespan_events"]
):
lifespan_started = False
lifespan_ended = False
if lifespan_style == "lifespan_generator":
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]:
nonlocal lifespan_started
nonlocal lifespan_ended
lifespan_started = True
yield
lifespan_ended = True
app = FastAPI(lifespan=lifespan)
elif lifespan_style == "events_decorator":
app = FastAPI()
with warnings.catch_warnings(action="ignore", category=DeprecationWarning):
@app.on_event("startup")
async def startup() -> None:
nonlocal lifespan_started
lifespan_started = True
@app.on_event("shutdown")
async def shutdown() -> None:
nonlocal lifespan_ended
lifespan_ended = True
elif lifespan_style == "events_constructor":
async def startup() -> None:
nonlocal lifespan_started
lifespan_started = True
async def shutdown() -> None:
nonlocal lifespan_ended
lifespan_ended = True
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown])
else:
assert_never(lifespan_style)
dependency_factory = DependencyFactory(dependency_style)
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan"
)],
expected_value=1
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket
)
assert lifespan_started and lifespan_ended
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class,
is_websocket
):
async def sub_dependency() -> None:
pass
async def dependency_func(param: Annotated[None, depends_class(sub_dependency)]) -> None:
yield
app = FastAPI()
with pytest.raises(DependencyScopeConflict):
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(dependency_func, dependency_scope="lifespan")],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
with pytest.raises(
InvalidDependencyScope,
match=r'Dependency "value" of .* has an invalid scope: '
r'"incorrect"'
):
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[None, Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
)]
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
with pytest.raises(InvalidDependencyScope):
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends]
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
)
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"]
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {}
try:
with pytest.raises(UninitializedLifespanDependency):
if is_websocket:
with client.websocket_connect("/test"):
pass
else:
client.post("/test")
finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = dependencies
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
)
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
internal_state = client.app_state["__fastapi__"]
del client.app_state["__fastapi__"]
try:
with pytest.raises(UninitializedLifespanDependency):
if is_websocket:
with client.websocket_connect("/test"):
pass
else:
client.post("/test")
finally:
client.app_state["__fastapi__"] = internal_state
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Endpoint", "Websocket"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(
use_cache,
dependency_style: DependencyStyle,
routing_style,
is_websocket
):
dependency_factory= DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1
)
if routing_style == "router_endpoint":
app.include_router(router)
with pytest.raises(IntentionallyBadDependency) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)

202
tests/test_lifespan_scoped_dependencies/testing_utilities.py

@ -0,0 +1,202 @@
from enum import StrEnum, auto
from typing import Any, AsyncGenerator, Generator, TypeVar, Union, assert_never
from fastapi import APIRouter, FastAPI, WebSocket
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
T = TypeVar('T')
class DependencyStyle(StrEnum):
SYNC_FUNCTION = auto()
ASYNC_FUNCTION = auto()
SYNC_GENERATOR = auto()
ASYNC_GENERATOR = auto()
class IntentionallyBadDependency(Exception):
pass
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle, *,
should_error: bool = False,
value_offset: int = 0,
):
self.activation_times = 0
self.deactivation_times = 0
self.dependency_style = dependency_style
self._should_error = should_error
self._value_offset = value_offset
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
return self._synchronous_function_dependency
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
return self._synchronous_generator_dependency
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
return self._asynchronous_function_dependency
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style)
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
yield self.activation_times + self._value_offset
self.deactivation_times += 1
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
yield self.activation_times + self._value_offset
self.deactivation_times += 1
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
return self.activation_times + self._value_offset
def _synchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
return self.activation_times + self._value_offset
def use_endpoint(client: TestClient, url: str) -> Any:
response = client.post(url)
response.raise_for_status()
return response.json()
def use_websocket(client: TestClient, url: str) -> Any:
with client.websocket_connect(url) as connection:
return connection.receive_json()
def create_endpoint_0_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(websocket: WebSocket) -> None:
await websocket.accept()
try:
await websocket.send_json(None)
except WebSocketDisconnect:
pass
else:
@router.post(path)
async def endpoint() -> None:
return None
def create_endpoint_1_annotation(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation: Any,
expected_value: Any = None
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value: annotation
) -> None:
if expected_value is not None:
assert value == expected_value
await websocket.accept()
try:
await websocket.send_json(value)
except WebSocketDisconnect:
pass
else:
@router.post(path)
async def endpoint(
value: annotation
) -> None:
if expected_value is not None:
assert value == expected_value
return value
def create_endpoint_2_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
) -> None:
await websocket.accept()
try:
await websocket.send_json([value1, value2])
except WebSocketDisconnect:
await websocket.close()
else:
@router.post(path)
async def endpoint(
value1: annotation1,
value2: annotation2,
) -> list[Any]:
return [value1, value2]
def create_endpoint_3_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
annotation3: Any
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
value3: annotation3
) -> None:
await websocket.accept()
try:
await websocket.send_json([value1, value2, value3])
except WebSocketDisconnect:
await websocket.close()
else:
@router.post(path)
async def endpoint(
value1: annotation1,
value2: annotation2,
value3: annotation3
) -> list[Any]:
return [value1, value2, value3]
Loading…
Cancel
Save