Browse Source

Merge 731b93202c into 8af92a6139

pull/12529/merge
Nir Schulman 3 days ago
committed by GitHub
parent
commit
d1ceaa7a73
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 111
      docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md
  2. 1
      docs/en/mkdocs.yml
  3. 44
      docs_src/dependencies/tutorial013a.py
  4. 42
      docs_src/dependencies/tutorial013a_an_py39.py
  5. 65
      docs_src/dependencies/tutorial013b.py
  6. 61
      docs_src/dependencies/tutorial013b_an_py39.py
  7. 50
      docs_src/dependencies/tutorial013c.py
  8. 55
      docs_src/dependencies/tutorial013c_an_py39.py
  9. 40
      docs_src/dependencies/tutorial013d.py
  10. 43
      docs_src/dependencies/tutorial013d_an_py39.py
  11. 58
      fastapi/applications.py
  12. 68
      fastapi/dependencies/models.py
  13. 291
      fastapi/dependencies/utils.py
  14. 16
      fastapi/exceptions.py
  15. 44
      fastapi/lifespan.py
  16. 6
      fastapi/openapi/utils.py
  17. 34
      fastapi/param_functions.py
  18. 21
      fastapi/params.py
  19. 57
      fastapi/routing.py
  20. 0
      tests/test_lifespan_scoped_dependencies/__init__.py
  21. 624
      tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py
  22. 920
      tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py
  23. 193
      tests/test_lifespan_scoped_dependencies/testing_utilities.py
  24. 35
      tests/test_params_repr.py
  25. 8
      tests/test_router_events.py
  26. 65
      tests/test_tutorial/test_dependencies/test_tutorial013a.py
  27. 70
      tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py
  28. 130
      tests/test_tutorial/test_dependencies/test_tutorial013b.py
  29. 135
      tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py
  30. 78
      tests/test_tutorial/test_dependencies/test_tutorial013c.py
  31. 83
      tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py
  32. 76
      tests/test_tutorial/test_dependencies/test_tutorial013d.py
  33. 81
      tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py

111
docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md

@ -0,0 +1,111 @@
# Lifespan Scoped Dependencies
## Intro
So far we've used dependencies which are "endpoint scoped". Meaning, they are
called again and again for every incoming request to the endpoint. However,
this is not always ideal:
* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance.
* Sometimes dependencies need to have their values shared throughout the lifespan
of the application between multiple requests.
An example of this would be a connection to a database. Databases are typically
less efficient when working with lots of connections and would prefer that
clients would create a single connection for their operations.
For such cases can be solved by using "lifespan scoped dependencies".
## What is a lifespan scoped dependency?
Lifespan scoped dependencies work similarly to the (endpoint scoped)
dependencies we've worked with so far. However, unlike endpoint scoped
dependencies, lifespan scoped dependencies are called once and only
once in the application's lifespan:
* During the application startup process, all lifespan scoped dependencies will
be called.
* Their returned value will be shared across all requests to the application.
* During the application's shutdown process, all lifespan scoped dependencies
will be gracefully teared down.
## Create a lifespan scoped dependency
You may declare a dependency as a lifespan scoped dependency by passing
`dependency_scope="lifespan"` to the `Depends` function:
{* ../../docs_src/dependencies/tutorial013a_an_py39.py *}
/// tip
In the example above we saved the annotation to a separate variable, and then
reused it in our endpoints. This is not a requirement, we could also declare
the exact same annotation in both endpoints. However, it is recommended that you
do save the annotation to a variable so you won't accidentally forget to pass
`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint
to create a new database connection for every request).
///
In this example, the `get_database_connection` dependency will be executed once,
during the application's startup. **FastAPI** will internally save the resulting
connection object, and whenever the `read_users` and `read_items` endpoints are
called, they will be using the previously saved connection. Once the application
shuts down, **FastAPI** will make sure to gracefully close the connection object.
## The `use_cache` argument
The `use_cache` argument works similarly to the way it worked with endpoint
scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it
will cache dependencies it already encountered before. However, you can disable
this behavior by passing `use_cache=False` to `Depends`:
{* ../../docs_src/dependencies/tutorial013b_an_py39.py *}
In this example, the `read_users` and `read_groups` endpoints are using
`use_cache=False` whereas the `read_items` and `read_item` are using
`use_cache=True`.
That means that we'll have a total of 3 connections created
for the duration of the application's lifespan:
* One connection will be shared across all requests for the `read_items` and `read_item` endpoints.
* A second connection will be shared across all requests for the `read_users` endpoint.
* A third and final connection will be shared across all requests for the `read_groups` endpoint.
## Lifespan Scoped Sub-Dependencies
Just like with endpoint scoped dependencies, lifespan scoped dependencies may
use other lifespan scoped sub-dependencies themselves:
{* ../../docs_src/dependencies/tutorial013c_an_py39.py *}
Endpoint scoped dependencies may use lifespan scoped sub dependencies as well:
{* ../../docs_src/dependencies/tutorial013d_an_py39.py *}
/// note
You can pass `dependency_scope="endpoint"` if you wish to explicitly specify
that a dependency is endpoint scoped. It will work the same as not specifying
a dependency scope at all.
///
As you can see, regardless of the scope, dependencies can use lifespan scoped
sub-dependencies.
## Dependency Scope Conflicts
By definition, lifespan scoped dependencies are being setup in the application's
startup process, before any request is ever being made to any endpoint.
Therefore, it is not possible for a lifespan scoped dependency to use any
parameters that require the scope of an endpoint.
That includes but not limited to:
* Parts of the request (like `Body`, `Query` and `Path`)
* The request/response objects themselves (like `Request`, `Response` and `WebSocket`)
* Endpoint scoped sub-dependencies.
Defining a dependency with such parameters will raise an `InvalidDependencyScope` error.

1
docs/en/mkdocs.yml

@ -141,6 +141,7 @@ nav:
- tutorial/dependencies/dependencies-in-path-operation-decorators.md
- tutorial/dependencies/global-dependencies.md
- tutorial/dependencies/dependencies-with-yield.md
- tutorial/dependencies/lifespan-scoped-dependencies.md
- Security:
- tutorial/security/index.md
- tutorial/security/first-steps.md

44
docs_src/dependencies/tutorial013a.py

@ -0,0 +1,44 @@
from typing import List
from fastapi import Depends, FastAPI
from typing_extensions import Self
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_records(self, table_name: str) -> List[dict]:
pass
app = FastAPI()
async def get_database_connection():
async with MyDatabaseConnection() as connection:
yield connection
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan")
@app.get("/users/")
async def read_users(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_records("users")
@app.get("/items/")
async def read_items(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_records("items")

42
docs_src/dependencies/tutorial013a_an_py39.py

@ -0,0 +1,42 @@
from typing import Annotated
from fastapi import Depends, FastAPI
from typing_extensions import Self
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_records(self, table_name: str) -> list[dict]:
pass
app = FastAPI()
async def get_database_connection():
async with MyDatabaseConnection() as connection:
yield connection
GlobalDatabaseConnection = Annotated[
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")
]
@app.get("/users/")
async def read_users(database_connection: GlobalDatabaseConnection):
return await database_connection.get_records("users")
@app.get("/items/")
async def read_items(database_connection: GlobalDatabaseConnection):
return await database_connection.get_records("items")

65
docs_src/dependencies/tutorial013b.py

@ -0,0 +1,65 @@
from typing import List
from fastapi import Depends, FastAPI, Path
from typing_extensions import Self
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_records(self, table_name: str) -> List[dict]:
pass
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
async def get_database_connection():
async with MyDatabaseConnection() as connection:
yield connection
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan")
DedicatedDatabaseConnection = Depends(
get_database_connection, dependency_scope="lifespan", use_cache=False
)
@app.get("/groups/")
async def read_groups(
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection,
):
return await database_connection.get_records("groups")
@app.get("/users/")
async def read_users(
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection,
):
return await database_connection.get_records("users")
@app.get("/items/")
async def read_items(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_records("items")
@app.get("/items/{item_id}")
async def read_item(
item_id: str = Path(),
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_record("items", item_id)

61
docs_src/dependencies/tutorial013b_an_py39.py

@ -0,0 +1,61 @@
from typing import Annotated
from fastapi import Depends, FastAPI, Path
from typing_extensions import Self
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_records(self, table_name: str) -> list[dict]:
pass
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
async def get_database_connection():
async with MyDatabaseConnection() as connection:
yield connection
GlobalDatabaseConnection = Annotated[
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")
]
DedicatedDatabaseConnection = Annotated[
MyDatabaseConnection,
Depends(get_database_connection, dependency_scope="lifespan", use_cache=False),
]
@app.get("/groups/")
async def read_groups(database_connection: DedicatedDatabaseConnection):
return await database_connection.get_records("groups")
@app.get("/users/")
async def read_users(database_connection: DedicatedDatabaseConnection):
return await database_connection.get_records("users")
@app.get("/items/")
async def read_items(database_connection: GlobalDatabaseConnection):
return await database_connection.get_records("items")
@app.get("/items/{item_id}")
async def read_item(
database_connection: GlobalDatabaseConnection, item_id: Annotated[str, Path()]
):
return await database_connection.get_record("items", item_id)

50
docs_src/dependencies/tutorial013c.py

@ -0,0 +1,50 @@
from dataclasses import dataclass
from fastapi import Depends, FastAPI, Path
from typing_extensions import Self
@dataclass
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
connection_string: str
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
async def get_configuration() -> dict:
return {
"database_url": "sqlite:///database.db",
}
GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan")
async def get_database_connection(configuration: dict = GlobalConfiguration):
async with MyDatabaseConnection(configuration["database_url"]) as connection:
yield connection
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan")
@app.get("/users/{user_id}")
async def read_user(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
user_id: str = Path(),
):
return await database_connection.get_record("users", user_id)

55
docs_src/dependencies/tutorial013c_an_py39.py

@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Annotated
from fastapi import Depends, FastAPI, Path
from typing_extensions import Self
@dataclass
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
connection_string: str
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
async def get_configuration() -> dict:
return {
"database_url": "sqlite:///database.db",
}
GlobalConfiguration = Annotated[
dict, Depends(get_configuration, dependency_scope="lifespan")
]
async def get_database_connection(configuration: GlobalConfiguration):
async with MyDatabaseConnection(configuration["database_url"]) as connection:
yield connection
GlobalDatabaseConnection = Annotated[
get_database_connection,
Depends(get_database_connection, dependency_scope="lifespan"),
]
@app.get("/users/{user_id}")
async def read_user(
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()]
):
return await database_connection.get_record("users", user_id)

40
docs_src/dependencies/tutorial013d.py

@ -0,0 +1,40 @@
from fastapi import Depends, FastAPI, Path
from typing_extensions import Self
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
async def get_database_connection():
async with MyDatabaseConnection() as connection:
yield connection
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan")
async def get_user_record(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
user_id: str = Path(),
) -> dict:
return await database_connection.get_record("users", user_id)
@app.get("/users/{user_id}")
async def read_user(user_record: dict = Depends(get_user_record)):
return user_record

43
docs_src/dependencies/tutorial013d_an_py39.py

@ -0,0 +1,43 @@
from typing import Annotated
from fastapi import Depends, FastAPI, Path
from typing_extensions import Self
class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
async def get_database_connection():
async with MyDatabaseConnection() as connection:
yield connection
GlobalDatabaseConnection = Annotated[
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")
]
async def get_user_record(
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()]
) -> dict:
return await database_connection.get_record("users", user_id)
@app.get("/users/{user_id}")
async def read_user(user_record: Annotated[dict, Depends(get_user_record)]):
return user_record

58
fastapi/applications.py

@ -1,6 +1,8 @@
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
@ -15,12 +17,14 @@ from typing import (
from fastapi import routing
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.utils import is_coroutine_callable
from fastapi.exception_handlers import (
http_exception_handler,
request_validation_exception_handler,
websocket_request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.lifespan import resolve_lifespan_dependants
from fastapi.logger import logger
from fastapi.openapi.docs import (
get_redoc_html,
@ -29,9 +33,11 @@ from fastapi.openapi.docs import (
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.routing import merge_lifespan_context
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import generate_unique_id
from starlette.applications import Starlette
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import State
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
@ -929,12 +935,26 @@ class FastAPI(Starlette):
"""
),
] = {}
if lifespan is None:
lifespan = FastAPI._internal_lifespan
else:
lifespan = merge_lifespan_context(FastAPI._internal_lifespan, lifespan)
# Since we always use a lifespan, starlette will no longer run event
# handlers which are defined in the scope of the application.
# We therefore need to call them ourselves.
if on_startup is None:
on_startup = []
if on_shutdown is None:
on_shutdown = []
self._on_startup = list(on_startup)
self._on_shutdown = list(on_shutdown)
self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
redirect_slashes=redirect_slashes,
dependency_overrides_provider=self,
on_startup=on_startup,
on_shutdown=on_shutdown,
lifespan=lifespan,
default_response_class=default_response_class,
dependencies=dependencies,
@ -963,6 +983,30 @@ class FastAPI(Starlette):
self.middleware_stack: Union[ASGIApp, None] = None
self.setup()
@asynccontextmanager
async def _internal_lifespan(self) -> AsyncGenerator[Dict[str, Any], None]:
async with AsyncExitStack() as exit_stack:
lifespan_scoped_dependencies = await resolve_lifespan_dependants(
app=self, async_exit_stack=exit_stack
)
try:
for handler in self._on_startup:
if is_coroutine_callable(handler):
await handler()
else:
await run_in_threadpool(handler)
yield {
"__fastapi__": {
"lifespan_scoped_dependencies": lifespan_scoped_dependencies
}
}
finally:
for handler in self._on_shutdown:
if is_coroutine_callable(handler):
await handler()
else:
await run_in_threadpool(handler)
def openapi(self) -> Dict[str, Any]:
"""
Generate the OpenAPI schema of the application. This is called by FastAPI
@ -4492,7 +4536,15 @@ class FastAPI(Starlette):
Read more about it in the
[FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/#alternative-events-deprecated).
"""
return self.router.on_event(event_type)
def decorator(func: DecoratedCallable) -> DecoratedCallable:
if event_type == "startup":
self._on_startup.append(func)
else:
self._on_shutdown.append(func)
return func
return decorator
def middleware(
self,

68
fastapi/dependencies/models.py

@ -1,8 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
from typing_extensions import TypeAlias
@dataclass
@ -11,17 +12,53 @@ class SecurityRequirement:
scopes: Optional[Sequence[str]] = None
LifespanDependantCacheKey: TypeAlias = Union[
Tuple[Callable[..., Any], Union[str, int]], Callable[..., Any]
]
@dataclass
class LifespanDependant:
call: Callable[..., Any]
caller: Callable[..., Any]
dependencies: List["LifespanDependant"] = field(default_factory=list)
name: Optional[str] = None
use_cache: bool = True
index: Optional[int] = None
cache_key: LifespanDependantCacheKey = field(init=False)
def __post_init__(self) -> None:
if self.use_cache:
self.cache_key = self.call
elif self.name is not None:
self.cache_key = (self.caller, self.name)
else:
assert self.index is not None, (
"Lifespan dependency must have an associated name or index."
)
self.cache_key = (self.caller, self.index)
EndpointDependantCacheKey: TypeAlias = Tuple[
Optional[Callable[..., Any]], Tuple[str, ...]
]
@dataclass
class Dependant:
class EndpointDependant:
endpoint_dependencies: List["EndpointDependant"] = field(default_factory=list)
lifespan_dependencies: List[LifespanDependant] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
use_cache: bool = True
index: Optional[int] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list)
cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None
websocket_param_name: Optional[str] = None
http_connection_param_name: Optional[str] = None
@ -29,9 +66,26 @@ class Dependant:
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
# Kept for backwards compatibility
@property
def dependencies(self) -> Tuple[Union["EndpointDependant", LifespanDependant], ...]:
lifespan_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]],
self.lifespan_dependencies,
)
endpoint_dependencies = cast(
List[Union[EndpointDependant, LifespanDependant]],
self.endpoint_dependencies,
)
return tuple(lifespan_dependencies + endpoint_dependencies)
# Kept for backwards compatibility
Dependant = EndpointDependant
CacheKey: TypeAlias = Union[EndpointDependantCacheKey, LifespanDependantCacheKey]

291
fastapi/dependencies/utils.py

@ -51,7 +51,19 @@ from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.models import (
CacheKey,
EndpointDependant,
EndpointDependantCacheKey,
LifespanDependant,
LifespanDependantCacheKey,
SecurityRequirement,
)
from fastapi.exceptions import (
DependencyScopeConflict,
InvalidDependencyScope,
UninitializedLifespanDependency,
)
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -120,8 +132,9 @@ def get_param_sub_dependant(
param_name: str,
depends: params.Depends,
path: str,
caller: Callable[..., Any],
security_scopes: Optional[List[str]] = None,
) -> Dependant:
) -> Union[EndpointDependant, LifespanDependant]:
assert depends.dependency
return get_sub_dependant(
depends=depends,
@ -129,14 +142,23 @@ def get_param_sub_dependant(
path=path,
name=param_name,
security_scopes=security_scopes,
caller=caller,
)
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
def get_parameterless_sub_dependant(
*, depends: params.Depends, path: str, caller: Callable[..., Any], index: int
) -> Union[EndpointDependant, LifespanDependant]:
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
return get_sub_dependant(
depends=depends,
dependency=depends.dependency,
path=path,
caller=caller,
index=index,
)
def get_sub_dependant(
@ -144,57 +166,72 @@ def get_sub_dependant(
depends: params.Depends,
dependency: Callable[..., Any],
path: str,
caller: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
) -> Dependant:
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
if isinstance(dependency, SecurityBase):
use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
index: Optional[int] = None,
) -> Union[EndpointDependant, LifespanDependant]:
if depends.dependency_scope == "lifespan":
return get_lifespan_dependant(
caller=caller,
call=dependency,
name=name,
use_cache=depends.use_cache,
index=index,
)
elif depends.dependency_scope == "endpoint":
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
if isinstance(dependency, SecurityBase):
use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_endpoint_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
index=index,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
else:
raise InvalidDependencyScope(
f'Dependency "{name}" of {caller} has an invalid '
f'scope: "{depends.dependency_scope}"'
)
def get_flat_dependant(
dependant: Dependant,
dependant: EndpointDependant,
*,
skip_repeats: bool = False,
visited: Optional[List[CacheKey]] = None,
) -> Dependant:
) -> EndpointDependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
flat_dependant = Dependant(
flat_dependant = EndpointDependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
lifespan_dependencies=dependant.lifespan_dependencies.copy(),
use_cache=dependant.use_cache,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
for sub_dependant in dependant.endpoint_dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
@ -206,6 +243,7 @@ def get_flat_dependant(
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
flat_dependant.lifespan_dependencies.extend(flat_sub.lifespan_dependencies)
return flat_dependant
@ -219,7 +257,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
return fields
def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_flat_params(dependant: EndpointDependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
@ -262,23 +300,75 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
def get_dependant(
def get_lifespan_dependant(
*,
caller: Callable[..., Any],
call: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True,
index: Optional[int] = None,
) -> LifespanDependant:
dependency_signature = get_typed_signature(call)
signature_params = dependency_signature.parameters
dependant = LifespanDependant(
call=call, name=name, use_cache=use_cache, caller=caller, index=index
)
for param_name, param in signature_params.items():
param_details = analyze_param(
param_name=param_name,
annotation=param.annotation,
value=param.default,
is_path_param=False,
)
if param_details.depends is None:
raise DependencyScopeConflict(
f'Lifespan scoped dependency "{dependant.name}" was defined '
f'with an invalid argument: "{param_name}" which is '
f'"endpoint" scoped. Lifespan scoped dependencies may only '
f"use lifespan scoped sub-dependencies."
)
if param_details.depends.dependency_scope != "lifespan":
raise DependencyScopeConflict(
f"Lifespan scoped dependency {dependant.name} was defined with the "
f'sub-dependency "{param_name}" which is '
f'"{param_details.depends.dependency_scope}" scoped. '
f"Lifespan scoped dependencies may only use lifespan scoped "
f"sub-dependencies."
)
assert param_details.depends.dependency is not None
sub_dependant = get_lifespan_dependant(
name=param_name,
call=param_details.depends.dependency,
use_cache=param_details.depends.use_cache,
caller=call,
)
dependant.dependencies.append(sub_dependant)
return dependant
def get_endpoint_dependant(
*,
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
) -> Dependant:
index: Optional[int] = None,
) -> EndpointDependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(
dependant = EndpointDependant(
call=call,
name=name,
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
index=index,
)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
@ -294,8 +384,13 @@ def get_dependant(
depends=param_details.depends,
path=path,
security_scopes=security_scopes,
caller=call,
)
dependant.dependencies.append(sub_dependant)
if isinstance(sub_dependant, EndpointDependant):
dependant.endpoint_dependencies.append(sub_dependant)
else:
assert isinstance(sub_dependant, LifespanDependant)
dependant.lifespan_dependencies.append(sub_dependant)
continue
if add_non_field_param_to_dependency(
param_name=param_name,
@ -314,8 +409,12 @@ def get_dependant(
return dependant
# Kept for backwards compatibility
get_dependant = get_endpoint_dependant
def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant
*, param_name: str, type_annotation: Any, dependant: EndpointDependant
) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name
@ -511,7 +610,7 @@ def analyze_param(
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
def add_param_to_fields(*, field: ModelField, dependant: EndpointDependant) -> None:
field_info = field.field_info
field_info_in = getattr(field_info, "in_", None)
if field_info_in == params.ParamTypes.path:
@ -560,36 +659,132 @@ async def solve_generator(
return await stack.enter_async_context(cm)
@dataclass
class SolvedLifespanDependant:
value: Any
dependency_cache: Dict[LifespanDependantCacheKey, Any]
async def solve_lifespan_dependant(
*,
dependant: LifespanDependant,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[
Dict[LifespanDependantCacheKey, Callable[..., Any]]
] = None,
async_exit_stack: AsyncExitStack,
) -> SolvedLifespanDependant:
dependency_cache = dependency_cache or {}
if dependant.use_cache and dependant.cache_key in dependency_cache:
return SolvedLifespanDependant(
value=dependency_cache[dependant.cache_key],
dependency_cache=dependency_cache,
)
call = dependant.call
dependant_to_solve = dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(
dependant.call, dependant.call
)
dependant_to_solve = get_lifespan_dependant(
caller=dependant.caller,
call=call,
name=dependant.name,
use_cache=dependant.use_cache,
index=dependant.index,
)
dependency_arguments: Dict[str, Any] = {}
for sub_dependant in dependant_to_solve.dependencies:
assert sub_dependant.name, (
"Lifespan scoped dependencies should not be able to have "
"subdependencies with no name"
)
solved_sub_dependant = await solve_lifespan_dependant(
dependant=sub_dependant,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
dependency_cache.update(solved_sub_dependant.dependency_cache)
dependency_arguments[sub_dependant.name] = solved_sub_dependant.value
if is_gen_callable(call) or is_async_gen_callable(call):
value = await solve_generator(
call=call, stack=async_exit_stack, sub_values=dependency_arguments
)
elif is_coroutine_callable(call):
value = await call(**dependency_arguments)
else:
value = await run_in_threadpool(call, **dependency_arguments)
if dependant.cache_key not in dependency_cache:
dependency_cache[dependant.cache_key] = value
return SolvedLifespanDependant(
value=value,
dependency_cache=dependency_cache,
)
@dataclass
class SolvedDependency:
values: Dict[str, Any]
errors: List[Any]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
dependency_cache: Dict[EndpointDependantCacheKey, Any]
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
dependant: EndpointDependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
dependency_cache: Optional[Dict[EndpointDependantCacheKey, Any]] = None,
async_exit_stack: AsyncExitStack,
embed_body_fields: bool,
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
for lifespan_sub_dependant in dependant.lifespan_dependencies:
if lifespan_sub_dependant.name is None:
continue
try:
lifespan_scoped_dependencies = request.state.__fastapi__[
"lifespan_scoped_dependencies"
]
except (AttributeError, KeyError) as e:
raise UninitializedLifespanDependency(
"FastAPI's internal lifespan was not initialized correctly."
) from e
try:
value = lifespan_scoped_dependencies[lifespan_sub_dependant.cache_key]
except KeyError as e:
raise UninitializedLifespanDependency(
f'Dependency "{lifespan_sub_dependant.name}" of '
f"`{dependant.call}` was not initialized correctly."
) from e
values[lifespan_sub_dependant.name] = value
if response is None:
response = Response()
del response.headers["content-length"]
response.status_code = None # type: ignore
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
for sub_dependant in dependant.endpoint_dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
@ -605,7 +800,7 @@ async def solve_dependencies(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
use_sub_dependant = get_endpoint_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
@ -949,7 +1144,7 @@ async def request_body_to_args(
def get_body_field(
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
*, flat_dependant: EndpointDependant, name: str, embed_body_fields: bool
) -> Optional[ModelField]:
"""
Get a ModelField representing the request body for a path operation, combining

16
fastapi/exceptions.py

@ -146,6 +146,22 @@ class FastAPIError(RuntimeError):
"""
class DependencyError(FastAPIError):
pass
class InvalidDependencyScope(DependencyError):
pass
class DependencyScopeConflict(DependencyError):
pass
class UninitializedLifespanDependency(DependencyError):
pass
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
self._errors = errors

44
fastapi/lifespan.py

@ -0,0 +1,44 @@
from __future__ import annotations
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Callable, Dict, List
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey
from fastapi.dependencies.utils import solve_lifespan_dependant
from fastapi.routing import APIRoute, APIWebSocketRoute
if TYPE_CHECKING: # pragma: nocover
from fastapi import FastAPI
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]:
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {}
for route in app.router.routes:
if not isinstance(route, (APIWebSocketRoute, APIRoute)):
continue
for sub_dependant in route.lifespan_dependencies:
if sub_dependant.cache_key in lifespan_dependants_cache:
continue
lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant
return list(lifespan_dependants_cache.values())
async def resolve_lifespan_dependants(
*, app: FastAPI, async_exit_stack: AsyncExitStack
) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]:
lifespan_dependants = _get_lifespan_dependants(app)
dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {}
for lifespan_dependant in lifespan_dependants:
solved_dependency = await solve_lifespan_dependant(
dependant=lifespan_dependant,
dependency_overrides_provider=app,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
dependency_cache.update(solved_dependency.dependency_cache)
return dependency_cache

6
fastapi/openapi/utils.py

@ -15,7 +15,7 @@ from fastapi._compat import (
lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.models import EndpointDependant
from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
@ -76,7 +76,7 @@ status_code_ranges: Dict[str, str] = {
def get_openapi_security_definitions(
flat_dependant: Dependant,
flat_dependant: EndpointDependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
@ -94,7 +94,7 @@ def get_openapi_security_definitions(
def _get_openapi_operation_parameters(
*,
dependant: Dependant,
dependant: EndpointDependant,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[

34
fastapi/param_functions.py

@ -1,8 +1,11 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
from fastapi.params import DependencyScope
from typing_extensions import Annotated, Doc, deprecated
_Unset: Any = Undefined
@ -2244,6 +2247,33 @@ def Depends( # noqa: N802
"""
),
] = True,
dependency_scope: Annotated[
DependencyScope,
Doc(
"""
The scope in which the dependency value should be evaluated. Can be
either `"endpoint"` or `"lifespan"`.
If `dependency_scope` is set to "endpoint" (the default), the
dependency will be setup and teardown for every request.
If `dependency_scope` is set to `"lifespan"` the dependency would
be setup at the start of the entire application's lifespan. The
evaluated dependency would be then reused across all endpoints.
The dependency would be teared down as a part of the application's
shutdown process.
Note that dependencies defined with the `"endpoint"` scope may use
sub-dependencies defined with the `"lifespan"` scope, but not the
other way around;
Dependencies defined with the `"lifespan"` scope may not use
sub-dependencies with `"endpoint"` scope, nor can they use
other "endpoint scoped" arguments such as "Path", "Body", "Query",
or any other annotation which does not make sense in a scope of an
application's entire lifespan.
"""
),
] = "endpoint",
) -> Any:
"""
Declare a FastAPI dependency.
@ -2274,7 +2304,9 @@ def Depends( # noqa: N802
return commons
```
"""
return params.Depends(dependency=dependency, use_cache=use_cache)
return params.Depends(
dependency=dependency, use_cache=use_cache, dependency_scope=dependency_scope
)
def Security( # noqa: N802

21
fastapi/params.py

@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated
from typing_extensions import Annotated, Literal, TypeAlias, deprecated
from ._compat import (
PYDANTIC_V2,
@ -13,6 +13,7 @@ from ._compat import (
)
_Unset: Any = Undefined
DependencyScope: TypeAlias = Literal["endpoint", "lifespan"]
class ParamTypes(Enum):
@ -763,15 +764,25 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
dependency_scope: DependencyScope = "endpoint",
):
self.dependency = dependency
self.use_cache = use_cache
self.dependency_scope = dependency_scope
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
if self.dependency_scope == "endpoint":
dependency_scope = ""
else:
dependency_scope = f', dependency_scope="{self.dependency_scope}"'
return f"{self.__class__.__name__}({attr}{cache}{dependency_scope})"
class Security(Depends):
@ -782,5 +793,7 @@ class Security(Depends):
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
super().__init__(
dependency=dependency, use_cache=use_cache, dependency_scope="endpoint"
)
self.scopes = scopes or []

57
fastapi/routing.py

@ -32,11 +32,11 @@ from fastapi._compat import (
lenient_issubclass,
)
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.models import EndpointDependant, LifespanDependant
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_endpoint_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
@ -124,7 +124,7 @@ def _prepare_response_content(
return res
def _merge_lifespan_context(
def merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@asynccontextmanager
@ -203,7 +203,7 @@ async def serialize_response(
async def run_endpoint_function(
*, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
*, dependant: EndpointDependant, values: Dict[str, Any], is_coroutine: bool
) -> Any:
# Only called by get_request_handler. Has been split into its own function to
# facilitate profiling endpoints, since inner functions are harder to profile.
@ -216,7 +216,7 @@ async def run_endpoint_function(
def get_request_handler(
dependant: Dependant,
dependant: EndpointDependant,
body_field: Optional[ModelField] = None,
status_code: Optional[int] = None,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
@ -359,7 +359,7 @@ def get_request_handler(
def get_websocket_app(
dependant: Dependant,
dependant: EndpointDependant,
dependency_overrides_provider: Optional[Any] = None,
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
@ -401,12 +401,20 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
self.dependant = get_endpoint_dependant(
path=self.path_format, call=self.endpoint
)
for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant(
depends=depends, path=self.path_format, caller=self.__call__, index=i
)
if isinstance(sub_dependant, EndpointDependant):
assert isinstance(sub_dependant, EndpointDependant)
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
else:
assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
@ -425,6 +433,10 @@ class APIWebSocketRoute(routing.WebSocketRoute):
child_scope["route"] = self
return match, child_scope
@property
def lifespan_dependencies(self) -> List[LifespanDependant]:
return self._flat_dependant.lifespan_dependencies
class APIRoute(routing.Route):
def __init__(
@ -552,12 +564,19 @@ class APIRoute(routing.Route):
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
self.dependant = get_endpoint_dependant(
path=self.path_format, call=self.endpoint
)
for i, depends in list(enumerate(self.dependencies))[::-1]:
sub_dependant = get_parameterless_sub_dependant(
depends=depends, path=self.path_format, caller=self.__call__, index=i
)
if isinstance(sub_dependant, EndpointDependant):
self.dependant.endpoint_dependencies.insert(0, sub_dependant)
else:
assert isinstance(sub_dependant, LifespanDependant)
self.dependant.lifespan_dependencies.insert(0, sub_dependant)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
@ -592,6 +611,10 @@ class APIRoute(routing.Route):
child_scope["route"] = self
return match, child_scope
@property
def lifespan_dependencies(self) -> List[LifespanDependant]:
return self._flat_dependant.lifespan_dependencies
class APIRouter(routing.Router):
"""
@ -1359,7 +1382,7 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context = merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)

0
tests/test_lifespan_scoped_dependencies/__init__.py

624
tests/test_lifespan_scoped_dependencies/test_dependency_overrides.py

@ -0,0 +1,624 @@
from typing import Any, AsyncGenerator, List, Tuple
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
Request,
WebSocket,
)
from fastapi.exceptions import DependencyScopeConflict
from fastapi.params import Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated, Literal
from tests.test_lifespan_scoped_dependencies.testing_utilities import (
DependencyFactory,
DependencyStyle,
IntentionallyBadDependency,
create_endpoint_0_annotations,
create_endpoint_1_annotation,
create_endpoint_3_annotations,
use_endpoint,
use_websocket,
)
def expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
override_dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool,
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == 0
assert override_dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
assert override_dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
if is_websocket:
response = use_websocket(client, url)
else:
response = use_endpoint(client, url)
assert response == expected_response
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
assert (
override_dependency_factory.activation_times
== expected_activation_times
)
assert override_dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == 0
assert override_dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION,
):
assert dependency_factory.deactivation_times == 0
assert (
override_dependency_factory.deactivation_times == expected_activation_times
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
),
],
expected_value=11,
)
if routing_style == "router_endpoint":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket,
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=app, path="/test", is_websocket=is_websocket
)
else:
app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=router, path="/test", is_websocket=is_websocket
)
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket,
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
List[int],
Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
),
],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [11, 11]),
("/test", [11, 11]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [11, 12]),
("/test", [11, 12]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=2,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test1",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 11, 11]),
("/test1", [11, 11, 11]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 12, 13]),
("/test1", [11, 12, 13]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=3,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test1",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
create_endpoint_3_annotations(
router=router,
path="/test2",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 11, 11]),
("/test2", [11, 11, 11]),
("/test1", [11, 11, 11]),
("/test2", [11, 11, 11]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [11, 12, 13]),
("/test2", [14, 15, 13]),
("/test1", [11, 12, 13]),
("/test2", [14, 15, 13]),
],
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=5,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle, routing_style, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
)
if routing_style == "router":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
urls_and_responses=[("/test", 11)] * 2,
expected_activation_times=1,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize(
"annotation",
[
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
Request,
WebSocket,
],
)
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation, is_websocket
):
async def dependency_func() -> None:
yield # pragma: nocover
async def override_dependency_func(param: annotation) -> None:
yield # pragma: nocover
app = FastAPI()
app.dependency_overrides[dependency_func] = override_dependency_func
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
with pytest.raises(DependencyScopeConflict):
with TestClient(app):
pass
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies(
dependency_style: DependencyStyle, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10)
async def lifespan_scoped_dependency(
param: Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan")
],
)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
override_dependency_factory=override_dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 11)] * 2,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class, is_websocket
):
async def sub_dependency() -> None:
pass # pragma: nocover
async def dependency_func() -> None:
yield # pragma: nocover
async def override_dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
yield # pragma: nocover
app = FastAPI()
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
app.dependency_overrides[dependency_func] = override_dependency_func
with pytest.raises(DependencyScopeConflict):
with TestClient(app):
pass
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_override_lifespan_scoped_dependencies(
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
override_dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
)
if routing_style == "router_endpoint":
app.include_router(router)
app.dependency_overrides[dependency_factory.get_dependency()] = (
override_dependency_factory.get_dependency()
)
with pytest.raises(IntentionallyBadDependency) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)

920
tests/test_lifespan_scoped_dependencies/test_endpoint_usage.py

@ -0,0 +1,920 @@
import warnings
from contextlib import asynccontextmanager
from time import sleep
from typing import Any, AsyncGenerator, Dict, List, Tuple
import pytest
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Cookie,
Depends,
FastAPI,
File,
Form,
Header,
Path,
Query,
Request,
WebSocket,
)
from fastapi.dependencies.utils import get_endpoint_dependant
from fastapi.exceptions import (
DependencyScopeConflict,
InvalidDependencyScope,
UninitializedLifespanDependency,
)
from fastapi.params import Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated, Literal
from tests.test_lifespan_scoped_dependencies.testing_utilities import (
DependencyFactory,
DependencyStyle,
IntentionallyBadDependency,
create_endpoint_0_annotations,
create_endpoint_1_annotation,
create_endpoint_2_annotations,
create_endpoint_3_annotations,
use_endpoint,
use_websocket,
)
def expect_correct_amount_of_dependency_activations(
*,
app: FastAPI,
dependency_factory: DependencyFactory,
urls_and_responses: List[Tuple[str, Any]],
expected_activation_times: int,
is_websocket: bool,
) -> None:
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
for url, expected_response in urls_and_responses:
if is_websocket:
assert use_websocket(client, url) == expected_response
else:
assert use_endpoint(client, url) == expected_response
assert dependency_factory.activation_times == expected_activation_times
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == expected_activation_times
if dependency_factory.dependency_style not in (
DependencyStyle.SYNC_FUNCTION,
DependencyStyle.ASYNC_FUNCTION,
):
assert dependency_factory.deactivation_times == expected_activation_times
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize(
"use_cache", [True, False], ids=["With Cache", "Without Cache"]
)
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoint_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
is_websocket: bool,
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
),
],
expected_value=1,
)
if routing_style == "router_endpoint":
app.include_router(router)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_duplication", [1, 2])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_router_dependencies(
dependency_style: DependencyStyle,
routing_style,
use_cache,
dependency_duplication,
is_websocket: bool,
):
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
if routing_style == "app":
app = FastAPI(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=app, path="/test", is_websocket=is_websocket
)
else:
app = FastAPI()
router = APIRouter(dependencies=[depends] * dependency_duplication)
create_endpoint_0_annotations(
router=router, path="/test", is_websocket=is_websocket
)
app.include_router(router)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", None)] * 2,
expected_activation_times=1 if use_cache else dependency_duplication,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"])
def test_dependency_cache_in_same_dependency(
dependency_style: DependencyStyle,
routing_style,
use_cache,
main_dependency_scope: Literal["endpoint", "lifespan"],
is_websocket: bool,
):
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def dependency(
sub_dependency1: Annotated[int, depends],
sub_dependency2: Annotated[int, depends],
) -> List[int]:
return [sub_dependency1, sub_dependency2]
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
List[int],
Depends(
dependency,
use_cache=use_cache,
dependency_scope=main_dependency_scope,
),
],
)
if routing_style == "router":
app.include_router(router)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1]),
("/test", [1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2]),
("/test", [1, 2]),
],
dependency_factory=dependency_factory,
expected_activation_times=2,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_same_endpoint(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
app.include_router(router)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 1, 1]),
("/test", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test", [1, 2, 3]),
("/test", [1, 2, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=3,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_dependency_cache_in_different_endpoints(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int:
return dependency3
create_endpoint_3_annotations(
router=router,
path="/test1",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
create_endpoint_3_annotations(
router=router,
path="/test2",
is_websocket=is_websocket,
annotation1=Annotated[int, depends],
annotation2=Annotated[int, depends],
annotation3=Annotated[int, Depends(endpoint_dependency)],
)
if routing_style == "router":
app.include_router(router)
if use_cache:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
("/test1", [1, 1, 1]),
("/test2", [1, 1, 1]),
],
dependency_factory=dependency_factory,
expected_activation_times=1,
is_websocket=is_websocket,
)
else:
expect_correct_amount_of_dependency_activations(
app=app,
urls_and_responses=[
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
("/test1", [1, 2, 3]),
("/test2", [4, 5, 3]),
],
dependency_factory=dependency_factory,
expected_activation_times=5,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_no_cached_dependency(
dependency_style: DependencyStyle,
routing_style,
is_websocket,
):
dependency_factory = DependencyFactory(dependency_style)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=False,
)
app = FastAPI()
if routing_style == "app":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1,
)
if routing_style == "router":
app.include_router(router)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
urls_and_responses=[("/test", 1)] * 2,
expected_activation_times=1,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize(
"annotation",
[
Annotated[str, Path()],
Annotated[str, Body()],
Annotated[str, Query()],
Annotated[str, Header()],
SecurityScopes,
Annotated[str, Cookie()],
Annotated[str, Form()],
Annotated[str, File()],
BackgroundTasks,
Request,
WebSocket,
],
)
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters(
annotation, is_websocket
):
async def dependency_func(param: annotation) -> None:
yield # pragma: nocover
app = FastAPI()
with pytest.raises(DependencyScopeConflict):
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies(
dependency_style: DependencyStyle, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
async def lifespan_scoped_dependency(
param: Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
) -> AsyncGenerator[int, None]:
yield param
app = FastAPI()
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, Depends(lifespan_scoped_dependency)],
expected_value=1,
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize(
["dependency_style", "supports_teardown"],
[
(DependencyStyle.SYNC_FUNCTION, False),
(DependencyStyle.ASYNC_FUNCTION, False),
(DependencyStyle.SYNC_GENERATOR, True),
(DependencyStyle.ASYNC_GENERATOR, True),
],
)
def test_the_same_dependency_can_work_in_different_scopes(
dependency_style: DependencyStyle, supports_teardown, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
create_endpoint_2_annotations(
router=app,
path="/test",
is_websocket=is_websocket,
annotation1=Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"),
],
annotation2=Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
)
if is_websocket:
get_response = use_websocket
else:
get_response = use_endpoint
assert dependency_factory.activation_times == 0
assert dependency_factory.deactivation_times == 0
with TestClient(app) as client:
assert dependency_factory.activation_times == 1
assert dependency_factory.deactivation_times == 0
assert get_response(client, "/test") == [2, 1]
assert dependency_factory.activation_times == 2
if supports_teardown:
if is_websocket:
# Websockets teardown might take some time after the test client
# has disconnected
sleep(0.1)
assert dependency_factory.deactivation_times == 1
else:
assert dependency_factory.deactivation_times == 0
assert get_response(client, "/test") == [3, 1]
assert dependency_factory.activation_times == 3
if supports_teardown:
if is_websocket:
# Websockets teardown might take some time after the test client
# has disconnected
sleep(0.1)
assert dependency_factory.deactivation_times == 2
else:
assert dependency_factory.deactivation_times == 0
assert dependency_factory.activation_times == 3
if supports_teardown:
assert dependency_factory.deactivation_times == 3
else:
assert dependency_factory.deactivation_times == 0
@pytest.mark.parametrize(
"lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"]
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans(
dependency_style: DependencyStyle,
is_websocket,
lifespan_style: Literal["lifespan_function", "lifespan_events"],
):
lifespan_started = False
lifespan_ended = False
if lifespan_style == "lifespan_generator":
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]:
nonlocal lifespan_started
nonlocal lifespan_ended
lifespan_started = True
yield
lifespan_ended = True
app = FastAPI(lifespan=lifespan)
elif lifespan_style == "events_decorator":
app = FastAPI()
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
@app.on_event("startup")
async def startup() -> None:
nonlocal lifespan_started
lifespan_started = True
@app.on_event("shutdown")
async def shutdown() -> None:
nonlocal lifespan_ended
lifespan_ended = True
else:
assert lifespan_style == "events_constructor"
async def startup() -> None:
nonlocal lifespan_started
lifespan_started = True
async def shutdown() -> None:
nonlocal lifespan_ended
lifespan_ended = True
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown])
dependency_factory = DependencyFactory(dependency_style)
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
expected_value=1,
)
expect_correct_amount_of_dependency_activations(
app=app,
dependency_factory=dependency_factory,
expected_activation_times=1,
urls_and_responses=[("/test", 1)] * 2,
is_websocket=is_websocket,
)
assert lifespan_started and lifespan_ended
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("depends_class", [Depends, Security])
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies(
depends_class, is_websocket
):
async def sub_dependency() -> None:
pass # pragma: nocover
async def dependency_func(
param: Annotated[None, depends_class(sub_dependency)],
) -> None:
pass # pragma: nocover
app = FastAPI()
with pytest.raises(DependencyScopeConflict):
create_endpoint_1_annotation(
router=app,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None, Depends(dependency_func, dependency_scope="lifespan")
],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_dependencies_must_provide_correct_dependency_scope(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
with pytest.raises(
InvalidDependencyScope,
match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"',
):
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[
None,
Depends(
dependency_factory.get_dependency(),
dependency_scope="incorrect",
use_cache=use_cache,
),
],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_incorrect_dependency_scope(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
with pytest.raises(InvalidDependencyScope):
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app", "router"])
def test_endpoints_report_incorrect_dependency_scope_at_router_scope(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR)
depends = Depends(dependency_factory.get_dependency(), dependency_scope="lifespan")
# We intentionally change the dependency scope here to bypass the
# validation at the function level.
depends.dependency_scope = "asdad"
if routing_style == "app":
app = FastAPI(dependencies=[depends])
router = app
else:
router = APIRouter(dependencies=[depends])
with pytest.raises(InvalidDependencyScope):
create_endpoint_0_annotations(
router=router,
path="/test",
is_websocket=is_websocket,
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_dependency(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1,
)
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"]
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {}
try:
with pytest.raises(UninitializedLifespanDependency):
if is_websocket:
with client.websocket_connect("/test"):
pass # pragma: nocover
else:
client.post("/test")
finally:
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = (
dependencies
)
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_endpoints_report_uninitialized_internal_lifespan(
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket
):
dependency_factory = DependencyFactory(dependency_style)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1,
)
if routing_style == "router_endpoint":
app.include_router(router)
with TestClient(app) as client:
internal_state = client.app_state["__fastapi__"]
del client.app_state["__fastapi__"]
try:
with pytest.raises(UninitializedLifespanDependency):
if is_websocket:
with client.websocket_connect("/test"):
pass # pragma: nocover
else:
client.post("/test")
finally:
client.app_state["__fastapi__"] = internal_state
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"])
@pytest.mark.parametrize("use_cache", [True, False])
@pytest.mark.parametrize("dependency_style", list(DependencyStyle))
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"])
def test_bad_lifespan_scoped_dependencies(
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket
):
dependency_factory = DependencyFactory(dependency_style, should_error=True)
depends = Depends(
dependency_factory.get_dependency(),
dependency_scope="lifespan",
use_cache=use_cache,
)
app = FastAPI()
if routing_style == "app_endpoint":
router = app
else:
router = APIRouter()
create_endpoint_1_annotation(
router=router,
path="/test",
is_websocket=is_websocket,
annotation=Annotated[int, depends],
expected_value=1,
)
if routing_style == "router_endpoint":
app.include_router(router)
with pytest.raises(IntentionallyBadDependency) as exception_info:
with TestClient(app):
pass
assert exception_info.value.args == (1,)
def test_endpoint_dependant_backwards_compatibility():
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR)
def endpoint(
dependency1: Annotated[int, Depends(dependency_factory.get_dependency())],
dependency2: Annotated[
int,
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"),
],
):
pass # pragma: nocover
dependant = get_endpoint_dependant(
path="/test",
call=endpoint,
name="endpoint",
)
assert dependant.dependencies == tuple(
dependant.lifespan_dependencies + dependant.endpoint_dependencies
)

193
tests/test_lifespan_scoped_dependencies/testing_utilities.py

@ -0,0 +1,193 @@
from enum import Enum
from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi.testclient import TestClient
from typing_extensions import assert_never
T = TypeVar("T")
class DependencyStyle(str, Enum):
SYNC_FUNCTION = "sync_function"
ASYNC_FUNCTION = "async_function"
SYNC_GENERATOR = "sync_generator"
ASYNC_GENERATOR = "async_generator"
class IntentionallyBadDependency(Exception):
pass
class DependencyFactory:
def __init__(
self,
dependency_style: DependencyStyle,
*,
should_error: bool = False,
value_offset: int = 0,
):
self.activation_times = 0
self.deactivation_times = 0
self.dependency_style = dependency_style
self._should_error = should_error
self._value_offset = value_offset
def get_dependency(self):
if self.dependency_style == DependencyStyle.SYNC_FUNCTION:
return self._synchronous_function_dependency
if self.dependency_style == DependencyStyle.SYNC_GENERATOR:
return self._synchronous_generator_dependency
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION:
return self._asynchronous_function_dependency
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR:
return self._asynchronous_generator_dependency
assert_never(self.dependency_style) # pragma: nocover
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
yield self.activation_times + self._value_offset
self.deactivation_times += 1
def _synchronous_generator_dependency(self) -> Generator[T, None, None]:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
yield self.activation_times + self._value_offset
self.deactivation_times += 1
async def _asynchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
return self.activation_times + self._value_offset
def _synchronous_function_dependency(self) -> T:
self.activation_times += 1
if self._should_error:
raise IntentionallyBadDependency(self.activation_times)
return self.activation_times + self._value_offset
def use_endpoint(client: TestClient, url: str) -> Any:
response = client.post(url)
response.raise_for_status()
return response.json()
def use_websocket(client: TestClient, url: str) -> Any:
with client.websocket_connect(url) as connection:
return connection.receive_json()
def create_endpoint_0_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_json(None)
else:
@router.post(path)
async def endpoint() -> None:
return None
def create_endpoint_1_annotation(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation: Any,
expected_value: Any = None,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(websocket: WebSocket, value: annotation) -> None:
if expected_value is not None:
assert value == expected_value
await websocket.accept()
await websocket.send_json(value)
else:
@router.post(path)
async def endpoint(value: annotation) -> Any:
if expected_value is not None:
assert value == expected_value
return value
def create_endpoint_2_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
) -> None:
await websocket.accept()
await websocket.send_json([value1, value2])
else:
@router.post(path)
async def endpoint(
value1: annotation1,
value2: annotation2,
) -> List[Any]:
return [value1, value2]
def create_endpoint_3_annotations(
*,
router: Union[APIRouter, FastAPI],
path: str,
is_websocket: bool,
annotation1: Any,
annotation2: Any,
annotation3: Any,
) -> None:
if is_websocket:
@router.websocket(path)
async def endpoint(
websocket: WebSocket,
value1: annotation1,
value2: annotation2,
value3: annotation3,
) -> None:
await websocket.accept()
await websocket.send_json([value1, value2, value3])
else:
@router.post(path)
async def endpoint(
value1: annotation1, value2: annotation2, value3: annotation3
) -> List[Any]:
return [value1, value2, value3]

35
tests/test_params_repr.py

@ -1,5 +1,6 @@
from typing import Any, List
import pytest
from dirty_equals import IsOneOf
from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query
@ -143,10 +144,30 @@ def test_body_repr_list():
assert repr(Body([])) == "Body([])"
def test_depends_repr():
assert repr(Depends()) == "Depends(NoneType)"
assert repr(Depends(get_user)) == "Depends(get_user)"
assert repr(Depends(use_cache=False)) == "Depends(NoneType, use_cache=False)"
assert (
repr(Depends(get_user, use_cache=False)) == "Depends(get_user, use_cache=False)"
)
@pytest.mark.parametrize(
["depends", "expected_repr"],
[
[Depends(), "Depends(NoneType)"],
[Depends(get_user), "Depends(get_user)"],
[Depends(use_cache=False), "Depends(NoneType, use_cache=False)"],
[Depends(get_user, use_cache=False), "Depends(get_user, use_cache=False)"],
[
Depends(dependency_scope="lifespan"),
'Depends(NoneType, dependency_scope="lifespan")',
],
[
Depends(get_user, dependency_scope="lifespan"),
'Depends(get_user, dependency_scope="lifespan")',
],
[
Depends(use_cache=False, dependency_scope="lifespan"),
'Depends(NoneType, use_cache=False, dependency_scope="lifespan")',
],
[
Depends(get_user, use_cache=False, dependency_scope="lifespan"),
'Depends(get_user, use_cache=False, dependency_scope="lifespan")',
],
],
)
def test_depends_repr(depends, expected_repr):
assert repr(depends) == expected_repr

8
tests/test_router_events.py

@ -199,6 +199,7 @@ def test_router_nested_lifespan_state_overriding_by_parent() -> None:
"app_specific": True,
"router_specific": True,
"overridden": "app",
"__fastapi__": {"lifespan_scoped_dependencies": {}},
}
@ -216,7 +217,7 @@ def test_merged_no_return_lifespans_return_none() -> None:
app.include_router(router)
with TestClient(app) as client:
assert not client.app_state
assert client.app_state == {"__fastapi__": {"lifespan_scoped_dependencies": {}}}
def test_merged_mixed_state_lifespans() -> None:
@ -239,4 +240,7 @@ def test_merged_mixed_state_lifespans() -> None:
app.include_router(router)
with TestClient(app) as client:
assert client.app_state == {"router": True}
assert client.app_state == {
"router": True,
"__fastapi__": {"lifespan_scoped_dependencies": {}},
}

65
tests/test_tutorial/test_dependencies/test_tutorial013a.py

@ -0,0 +1,65 @@
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app
class MockDatabaseConnection:
def __init__(self):
self.enter_count = 0
self.exit_count = 0
self.get_records_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_records(self, table_name: str) -> List[dict]:
self.get_records_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_records(self, table_name)
return []
@pytest.fixture
def database_connection_mock(monkeypatch) -> MockDatabaseConnection:
mock = MockDatabaseConnection()
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock)
return mock
def test_dependency_usage(database_connection_mock):
assert database_connection_mock.enter_count == 0
assert database_connection_mock.exit_count == 0
with TestClient(app) as test_client:
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
assert database_connection_mock.get_records_count == 1
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []
assert database_connection_mock.get_records_count == 2
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 1

70
tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py

@ -0,0 +1,70 @@
import sys
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
if sys.version_info >= (3, 9):
from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app
from ...utils import needs_py39
class MockDatabaseConnection:
def __init__(self):
self.enter_count = 0
self.exit_count = 0
self.get_records_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_records(self, table_name: str) -> List[dict]:
self.get_records_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_records(self, table_name)
return []
@pytest.fixture
def database_connection_mock(monkeypatch) -> MockDatabaseConnection:
mock = MockDatabaseConnection()
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock)
return mock
@needs_py39
def test_dependency_usage(database_connection_mock):
assert database_connection_mock.enter_count == 0
assert database_connection_mock.exit_count == 0
with TestClient(app) as test_client:
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
assert database_connection_mock.get_records_count == 1
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []
assert database_connection_mock.get_records_count == 2
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 1

130
tests/test_tutorial/test_dependencies/test_tutorial013b.py

@ -0,0 +1,130 @@
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app
class MockDatabaseConnection:
def __init__(self):
self.enter_count = 0
self.exit_count = 0
self.get_records_count = 0
self.get_record_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_records(self, table_name: str) -> List[dict]:
self.get_records_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_records(self, table_name)
return []
async def get_record(self, table_name: str, record_id: str) -> dict:
self.get_record_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_record(self, table_name, record_id)
return {
"table_name": table_name,
"record_id": record_id,
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
def test_dependency_usage(database_connection_mocks):
assert len(database_connection_mocks) == 0
with TestClient(app) as test_client:
assert len(database_connection_mocks) == 3
for connection in database_connection_mocks:
assert connection.enter_count == 1
assert connection.exit_count == 0
assert connection.get_records_count == 0
assert connection.get_record_count == 0
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
users_connection = None
for connection in database_connection_mocks:
if connection.get_records_count == 1:
users_connection = connection
break
assert users_connection is not None, (
"No connection was found for users endpoint"
)
response = test_client.get("/groups")
assert response.status_code == 200
assert response.json() == []
groups_connection = None
for connection in database_connection_mocks:
if connection.get_records_count == 1 and connection is not users_connection:
groups_connection = connection
break
assert groups_connection is not None, (
"No connection was found for groups endpoint"
)
assert groups_connection.get_records_count == 1
items_connection = None
for connection in database_connection_mocks:
if connection.get_records_count == 0:
items_connection = connection
break
assert items_connection is not None, (
"No connection was found for items endpoint"
)
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []
assert items_connection.get_records_count == 1
assert items_connection.get_record_count == 0
response = test_client.get("/items/asd")
assert response.status_code == 200
assert response.json() == {
"table_name": "items",
"record_id": "asd",
}
assert items_connection.get_records_count == 1
assert items_connection.get_record_count == 1
for connection in database_connection_mocks:
assert connection.enter_count == 1
assert connection.exit_count == 0
for connection in database_connection_mocks:
assert connection.enter_count == 1
assert connection.exit_count == 1

135
tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py

@ -0,0 +1,135 @@
import sys
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
if sys.version_info >= (3, 9):
from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app
from ...utils import needs_py39
class MockDatabaseConnection:
def __init__(self):
self.enter_count = 0
self.exit_count = 0
self.get_records_count = 0
self.get_record_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_records(self, table_name: str) -> List[dict]:
self.get_records_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_records(self, table_name)
return []
async def get_record(self, table_name: str, record_id: str) -> dict:
self.get_record_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_record(self, table_name, record_id)
return {
"table_name": table_name,
"record_id": record_id,
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@needs_py39
def test_dependency_usage(database_connection_mocks):
assert len(database_connection_mocks) == 0
with TestClient(app) as test_client:
assert len(database_connection_mocks) == 3
for connection in database_connection_mocks:
assert connection.enter_count == 1
assert connection.exit_count == 0
assert connection.get_records_count == 0
assert connection.get_record_count == 0
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
users_connection = None
for connection in database_connection_mocks:
if connection.get_records_count == 1:
users_connection = connection
break
assert users_connection is not None, (
"No connection was found for users endpoint"
)
response = test_client.get("/groups")
assert response.status_code == 200
assert response.json() == []
groups_connection = None
for connection in database_connection_mocks:
if connection.get_records_count == 1 and connection is not users_connection:
groups_connection = connection
break
assert groups_connection is not None, (
"No connection was found for groups endpoint"
)
assert groups_connection.get_records_count == 1
items_connection = None
for connection in database_connection_mocks:
if connection.get_records_count == 0:
items_connection = connection
break
assert items_connection is not None, (
"No connection was found for items endpoint"
)
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []
assert items_connection.get_records_count == 1
assert items_connection.get_record_count == 0
response = test_client.get("/items/asd")
assert response.status_code == 200
assert response.json() == {
"table_name": "items",
"record_id": "asd",
}
assert items_connection.get_records_count == 1
assert items_connection.get_record_count == 1
for connection in database_connection_mocks:
assert connection.enter_count == 1
assert connection.exit_count == 0
for connection in database_connection_mocks:
assert connection.enter_count == 1
assert connection.exit_count == 1

78
tests/test_tutorial/test_dependencies/test_tutorial013c.py

@ -0,0 +1,78 @@
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app
class MockDatabaseConnection:
def __init__(self, url: str):
self.url = url
self.enter_count = 0
self.exit_count = 0
self.get_record_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_record(self, table_name: str, record_id: str) -> dict:
self.get_record_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_record(self, table_name, record_id)
return {
"table_name": table_name,
"record_id": record_id,
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(cls, url):
mock = MockDatabaseConnection(url)
connections.append(mock)
return mock
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
def test_dependency_usage(database_connection_mocks):
assert len(database_connection_mocks) == 0
with TestClient(app) as test_client:
assert len(database_connection_mocks) == 1
[database_connection_mock] = database_connection_mocks
assert database_connection_mock.url == "sqlite:///database.db"
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",
"record_id": "user",
}
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 1
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 1
assert database_connection_mock.get_record_count == 1
assert len(database_connection_mocks) == 1

83
tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py

@ -0,0 +1,83 @@
import sys
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
if sys.version_info >= (3, 9):
from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app
from ...utils import needs_py39
class MockDatabaseConnection:
def __init__(self, url: str):
self.url = url
self.enter_count = 0
self.exit_count = 0
self.get_record_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_record(self, table_name: str, record_id: str) -> dict:
self.get_record_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_record(self, table_name, record_id)
return {
"table_name": table_name,
"record_id": record_id,
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(cls, url):
mock = MockDatabaseConnection(url)
connections.append(mock)
return mock
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@needs_py39
def test_dependency_usage(database_connection_mocks):
assert len(database_connection_mocks) == 0
with TestClient(app) as test_client:
assert len(database_connection_mocks) == 1
[database_connection_mock] = database_connection_mocks
assert database_connection_mock.url == "sqlite:///database.db"
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",
"record_id": "user",
}
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 1
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 1
assert database_connection_mock.get_record_count == 1
assert len(database_connection_mocks) == 1

76
tests/test_tutorial/test_dependencies/test_tutorial013d.py

@ -0,0 +1,76 @@
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app
class MockDatabaseConnection:
def __init__(self):
self.enter_count = 0
self.exit_count = 0
self.get_record_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_record(self, table_name: str, record_id: str) -> dict:
self.get_record_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_record(self, table_name, record_id)
return {
"table_name": table_name,
"record_id": record_id,
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
def test_dependency_usage(database_connection_mocks):
assert len(database_connection_mocks) == 0
with TestClient(app) as test_client:
assert len(database_connection_mocks) == 1
[database_connection_mock] = database_connection_mocks
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",
"record_id": "user",
}
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 1
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 1
assert database_connection_mock.get_record_count == 1
assert len(database_connection_mocks) == 1

81
tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py

@ -0,0 +1,81 @@
import sys
from typing import List
import pytest
from starlette.testclient import TestClient
from typing_extensions import Self
if sys.version_info >= (3, 9):
from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app
from ...utils import needs_py39
class MockDatabaseConnection:
def __init__(self):
self.enter_count = 0
self.exit_count = 0
self.get_record_count = 0
async def __aenter__(self) -> Self:
self.enter_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aenter__(self)
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.exit_count += 1
# Called for the sake of coverage.
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb)
async def get_record(self, table_name: str, record_id: str) -> dict:
self.get_record_count += 1
# Called for the sake of coverage.
await MyDatabaseConnection.get_record(self, table_name, record_id)
return {
"table_name": table_name,
"record_id": record_id,
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@needs_py39
def test_dependency_usage(database_connection_mocks):
assert len(database_connection_mocks) == 0
with TestClient(app) as test_client:
assert len(database_connection_mocks) == 1
[database_connection_mock] = database_connection_mocks
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",
"record_id": "user",
}
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 1
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 1
assert database_connection_mock.get_record_count == 1
assert len(database_connection_mocks) == 1
Loading…
Cancel
Save