committed by
GitHub
33 changed files with 3513 additions and 92 deletions
@ -0,0 +1,111 @@ |
|||
# Lifespan Scoped Dependencies |
|||
|
|||
## Intro |
|||
|
|||
So far we've used dependencies which are "endpoint scoped". Meaning, they are |
|||
called again and again for every incoming request to the endpoint. However, |
|||
this is not always ideal: |
|||
|
|||
* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance. |
|||
* Sometimes dependencies need to have their values shared throughout the lifespan |
|||
of the application between multiple requests. |
|||
|
|||
|
|||
An example of this would be a connection to a database. Databases are typically |
|||
less efficient when working with lots of connections and would prefer that |
|||
clients would create a single connection for their operations. |
|||
|
|||
For such cases can be solved by using "lifespan scoped dependencies". |
|||
|
|||
|
|||
## What is a lifespan scoped dependency? |
|||
Lifespan scoped dependencies work similarly to the (endpoint scoped) |
|||
dependencies we've worked with so far. However, unlike endpoint scoped |
|||
dependencies, lifespan scoped dependencies are called once and only |
|||
once in the application's lifespan: |
|||
|
|||
* During the application startup process, all lifespan scoped dependencies will |
|||
be called. |
|||
* Their returned value will be shared across all requests to the application. |
|||
* During the application's shutdown process, all lifespan scoped dependencies |
|||
will be gracefully teared down. |
|||
|
|||
|
|||
## Create a lifespan scoped dependency |
|||
|
|||
You may declare a dependency as a lifespan scoped dependency by passing |
|||
`dependency_scope="lifespan"` to the `Depends` function: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013a_an_py39.py *} |
|||
|
|||
/// tip |
|||
|
|||
In the example above we saved the annotation to a separate variable, and then |
|||
reused it in our endpoints. This is not a requirement, we could also declare |
|||
the exact same annotation in both endpoints. However, it is recommended that you |
|||
do save the annotation to a variable so you won't accidentally forget to pass |
|||
`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint |
|||
to create a new database connection for every request). |
|||
|
|||
/// |
|||
|
|||
In this example, the `get_database_connection` dependency will be executed once, |
|||
during the application's startup. **FastAPI** will internally save the resulting |
|||
connection object, and whenever the `read_users` and `read_items` endpoints are |
|||
called, they will be using the previously saved connection. Once the application |
|||
shuts down, **FastAPI** will make sure to gracefully close the connection object. |
|||
|
|||
## The `use_cache` argument |
|||
|
|||
The `use_cache` argument works similarly to the way it worked with endpoint |
|||
scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it |
|||
will cache dependencies it already encountered before. However, you can disable |
|||
this behavior by passing `use_cache=False` to `Depends`: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013b_an_py39.py *} |
|||
|
|||
In this example, the `read_users` and `read_groups` endpoints are using |
|||
`use_cache=False` whereas the `read_items` and `read_item` are using |
|||
`use_cache=True`. |
|||
That means that we'll have a total of 3 connections created |
|||
for the duration of the application's lifespan: |
|||
|
|||
* One connection will be shared across all requests for the `read_items` and `read_item` endpoints. |
|||
* A second connection will be shared across all requests for the `read_users` endpoint. |
|||
* A third and final connection will be shared across all requests for the `read_groups` endpoint. |
|||
|
|||
|
|||
## Lifespan Scoped Sub-Dependencies |
|||
Just like with endpoint scoped dependencies, lifespan scoped dependencies may |
|||
use other lifespan scoped sub-dependencies themselves: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013c_an_py39.py *} |
|||
|
|||
Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013d_an_py39.py *} |
|||
|
|||
/// note |
|||
|
|||
You can pass `dependency_scope="endpoint"` if you wish to explicitly specify |
|||
that a dependency is endpoint scoped. It will work the same as not specifying |
|||
a dependency scope at all. |
|||
|
|||
/// |
|||
|
|||
As you can see, regardless of the scope, dependencies can use lifespan scoped |
|||
sub-dependencies. |
|||
|
|||
## Dependency Scope Conflicts |
|||
By definition, lifespan scoped dependencies are being setup in the application's |
|||
startup process, before any request is ever being made to any endpoint. |
|||
Therefore, it is not possible for a lifespan scoped dependency to use any |
|||
parameters that require the scope of an endpoint. |
|||
|
|||
That includes but not limited to: |
|||
|
|||
* Parts of the request (like `Body`, `Query` and `Path`) |
|||
* The request/response objects themselves (like `Request`, `Response` and `WebSocket`) |
|||
* Endpoint scoped sub-dependencies. |
|||
|
|||
Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. |
@ -0,0 +1,44 @@ |
|||
from typing import List |
|||
|
|||
from fastapi import Depends, FastAPI |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
|
|||
|
|||
@app.get("/users/") |
|||
async def read_users( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
): |
|||
return await database_connection.get_records("items") |
@ -0,0 +1,42 @@ |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> list[dict]: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[ |
|||
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") |
|||
] |
|||
|
|||
|
|||
@app.get("/users/") |
|||
async def read_users(database_connection: GlobalDatabaseConnection): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items(database_connection: GlobalDatabaseConnection): |
|||
return await database_connection.get_records("items") |
@ -0,0 +1,65 @@ |
|||
from typing import List |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
DedicatedDatabaseConnection = Depends( |
|||
get_database_connection, dependency_scope="lifespan", use_cache=False |
|||
) |
|||
|
|||
|
|||
@app.get("/groups/") |
|||
async def read_groups( |
|||
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, |
|||
): |
|||
return await database_connection.get_records("groups") |
|||
|
|||
|
|||
@app.get("/users/") |
|||
async def read_users( |
|||
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, |
|||
): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
): |
|||
return await database_connection.get_records("items") |
|||
|
|||
|
|||
@app.get("/items/{item_id}") |
|||
async def read_item( |
|||
item_id: str = Path(), |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
): |
|||
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,61 @@ |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> list[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[ |
|||
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") |
|||
] |
|||
DedicatedDatabaseConnection = Annotated[ |
|||
MyDatabaseConnection, |
|||
Depends(get_database_connection, dependency_scope="lifespan", use_cache=False), |
|||
] |
|||
|
|||
|
|||
@app.get("/groups/") |
|||
async def read_groups(database_connection: DedicatedDatabaseConnection): |
|||
return await database_connection.get_records("groups") |
|||
|
|||
|
|||
@app.get("/users/") |
|||
async def read_users(database_connection: DedicatedDatabaseConnection): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items(database_connection: GlobalDatabaseConnection): |
|||
return await database_connection.get_records("items") |
|||
|
|||
|
|||
@app.get("/items/{item_id}") |
|||
async def read_item( |
|||
database_connection: GlobalDatabaseConnection, item_id: Annotated[str, Path()] |
|||
): |
|||
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,50 @@ |
|||
from dataclasses import dataclass |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
@dataclass |
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
connection_string: str |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_configuration() -> dict: |
|||
return { |
|||
"database_url": "sqlite:///database.db", |
|||
} |
|||
|
|||
|
|||
GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") |
|||
|
|||
|
|||
async def get_database_connection(configuration: dict = GlobalConfiguration): |
|||
async with MyDatabaseConnection(configuration["database_url"]) as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
|
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
user_id: str = Path(), |
|||
): |
|||
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,55 @@ |
|||
from dataclasses import dataclass |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
@dataclass |
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
connection_string: str |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_configuration() -> dict: |
|||
return { |
|||
"database_url": "sqlite:///database.db", |
|||
} |
|||
|
|||
|
|||
GlobalConfiguration = Annotated[ |
|||
dict, Depends(get_configuration, dependency_scope="lifespan") |
|||
] |
|||
|
|||
|
|||
async def get_database_connection(configuration: GlobalConfiguration): |
|||
async with MyDatabaseConnection(configuration["database_url"]) as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[ |
|||
get_database_connection, |
|||
Depends(get_database_connection, dependency_scope="lifespan"), |
|||
] |
|||
|
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user( |
|||
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] |
|||
): |
|||
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,40 @@ |
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
|
|||
|
|||
async def get_user_record( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
user_id: str = Path(), |
|||
) -> dict: |
|||
return await database_connection.get_record("users", user_id) |
|||
|
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user(user_record: dict = Depends(get_user_record)): |
|||
return user_record |
@ -0,0 +1,43 @@ |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[ |
|||
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") |
|||
] |
|||
|
|||
|
|||
async def get_user_record( |
|||
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] |
|||
) -> dict: |
|||
return await database_connection.get_record("users", user_id) |
|||
|
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user(user_record: Annotated[dict, Depends(get_user_record)]): |
|||
return user_record |
@ -0,0 +1,44 @@ |
|||
from __future__ import annotations |
|||
|
|||
from contextlib import AsyncExitStack |
|||
from typing import TYPE_CHECKING, Any, Callable, Dict, List |
|||
|
|||
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey |
|||
from fastapi.dependencies.utils import solve_lifespan_dependant |
|||
from fastapi.routing import APIRoute, APIWebSocketRoute |
|||
|
|||
if TYPE_CHECKING: # pragma: nocover |
|||
from fastapi import FastAPI |
|||
|
|||
|
|||
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: |
|||
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} |
|||
for route in app.router.routes: |
|||
if not isinstance(route, (APIWebSocketRoute, APIRoute)): |
|||
continue |
|||
|
|||
for sub_dependant in route.lifespan_dependencies: |
|||
if sub_dependant.cache_key in lifespan_dependants_cache: |
|||
continue |
|||
|
|||
lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant |
|||
|
|||
return list(lifespan_dependants_cache.values()) |
|||
|
|||
|
|||
async def resolve_lifespan_dependants( |
|||
*, app: FastAPI, async_exit_stack: AsyncExitStack |
|||
) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: |
|||
lifespan_dependants = _get_lifespan_dependants(app) |
|||
dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} |
|||
for lifespan_dependant in lifespan_dependants: |
|||
solved_dependency = await solve_lifespan_dependant( |
|||
dependant=lifespan_dependant, |
|||
dependency_overrides_provider=app, |
|||
dependency_cache=dependency_cache, |
|||
async_exit_stack=async_exit_stack, |
|||
) |
|||
|
|||
dependency_cache.update(solved_dependency.dependency_cache) |
|||
|
|||
return dependency_cache |
@ -0,0 +1,624 @@ |
|||
from typing import Any, AsyncGenerator, List, Tuple |
|||
|
|||
import pytest |
|||
from fastapi import ( |
|||
APIRouter, |
|||
BackgroundTasks, |
|||
Body, |
|||
Cookie, |
|||
Depends, |
|||
FastAPI, |
|||
File, |
|||
Form, |
|||
Header, |
|||
Path, |
|||
Query, |
|||
Request, |
|||
WebSocket, |
|||
) |
|||
from fastapi.exceptions import DependencyScopeConflict |
|||
from fastapi.params import Security |
|||
from fastapi.security import SecurityScopes |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated, Literal |
|||
|
|||
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
|||
DependencyFactory, |
|||
DependencyStyle, |
|||
IntentionallyBadDependency, |
|||
create_endpoint_0_annotations, |
|||
create_endpoint_1_annotation, |
|||
create_endpoint_3_annotations, |
|||
use_endpoint, |
|||
use_websocket, |
|||
) |
|||
|
|||
|
|||
def expect_correct_amount_of_dependency_activations( |
|||
*, |
|||
app: FastAPI, |
|||
dependency_factory: DependencyFactory, |
|||
override_dependency_factory: DependencyFactory, |
|||
urls_and_responses: List[Tuple[str, Any]], |
|||
expected_activation_times: int, |
|||
is_websocket: bool, |
|||
) -> None: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert override_dependency_factory.activation_times == 0 |
|||
assert override_dependency_factory.deactivation_times == 0 |
|||
|
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert override_dependency_factory.activation_times == expected_activation_times |
|||
assert override_dependency_factory.deactivation_times == 0 |
|||
|
|||
for url, expected_response in urls_and_responses: |
|||
if is_websocket: |
|||
response = use_websocket(client, url) |
|||
else: |
|||
response = use_endpoint(client, url) |
|||
|
|||
assert response == expected_response |
|||
|
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert ( |
|||
override_dependency_factory.activation_times |
|||
== expected_activation_times |
|||
) |
|||
assert override_dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == 0 |
|||
assert override_dependency_factory.activation_times == expected_activation_times |
|||
if dependency_factory.dependency_style not in ( |
|||
DependencyStyle.SYNC_FUNCTION, |
|||
DependencyStyle.ASYNC_FUNCTION, |
|||
): |
|||
assert dependency_factory.deactivation_times == 0 |
|||
assert ( |
|||
override_dependency_factory.deactivation_times == expected_activation_times |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoint_dependencies( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, |
|||
Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
), |
|||
], |
|||
expected_value=11, |
|||
) |
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
urls_and_responses=[("/test", 11)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_router_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
dependency_duplication, |
|||
is_websocket, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=app, path="/test", is_websocket=is_websocket |
|||
) |
|||
else: |
|||
app = FastAPI() |
|||
router = APIRouter(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=router, path="/test", is_websocket=is_websocket |
|||
) |
|||
|
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
urls_and_responses=[("/test", None)] * 2, |
|||
expected_activation_times=1 if use_cache else dependency_duplication, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|||
def test_dependency_cache_in_same_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
main_dependency_scope: Literal["endpoint", "lifespan"], |
|||
is_websocket, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def dependency( |
|||
sub_dependency1: Annotated[int, depends], |
|||
sub_dependency2: Annotated[int, depends], |
|||
) -> List[int]: |
|||
return [sub_dependency1, sub_dependency2] |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
List[int], |
|||
Depends( |
|||
dependency, |
|||
use_cache=use_cache, |
|||
dependency_scope=main_dependency_scope, |
|||
), |
|||
], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [11, 11]), |
|||
("/test", [11, 11]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [11, 12]), |
|||
("/test", [11, 12]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=2, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_same_endpoint( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test1", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 11, 11]), |
|||
("/test1", [11, 11, 11]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 12, 13]), |
|||
("/test1", [11, 12, 13]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=3, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_different_endpoints( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test1", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test2", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 11, 11]), |
|||
("/test2", [11, 11, 11]), |
|||
("/test1", [11, 11, 11]), |
|||
("/test2", [11, 11, 11]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [11, 12, 13]), |
|||
("/test2", [14, 15, 13]), |
|||
("/test1", [11, 12, 13]), |
|||
("/test2", [14, 15, 13]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=5, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_no_cached_dependency( |
|||
dependency_style: DependencyStyle, routing_style, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=False, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
urls_and_responses=[("/test", 11)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize( |
|||
"annotation", |
|||
[ |
|||
Annotated[str, Path()], |
|||
Annotated[str, Body()], |
|||
Annotated[str, Query()], |
|||
Annotated[str, Header()], |
|||
SecurityScopes, |
|||
Annotated[str, Cookie()], |
|||
Annotated[str, Form()], |
|||
Annotated[str, File()], |
|||
BackgroundTasks, |
|||
Request, |
|||
WebSocket, |
|||
], |
|||
) |
|||
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|||
annotation, is_websocket |
|||
): |
|||
async def dependency_func() -> None: |
|||
yield # pragma: nocover |
|||
|
|||
async def override_dependency_func(param: annotation) -> None: |
|||
yield # pragma: nocover |
|||
|
|||
app = FastAPI() |
|||
app.dependency_overrides[dependency_func] = override_dependency_func |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, Depends(dependency_func, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( |
|||
dependency_style: DependencyStyle, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
|||
|
|||
async def lifespan_scoped_dependency( |
|||
param: Annotated[ |
|||
int, |
|||
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
|||
], |
|||
) -> AsyncGenerator[int, None]: |
|||
yield param |
|||
|
|||
app = FastAPI() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
override_dependency_factory=override_dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 11)] * 2, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|||
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|||
depends_class, is_websocket |
|||
): |
|||
async def sub_dependency() -> None: |
|||
pass # pragma: nocover |
|||
|
|||
async def dependency_func() -> None: |
|||
yield # pragma: nocover |
|||
|
|||
async def override_dependency_func( |
|||
param: Annotated[None, depends_class(sub_dependency)], |
|||
) -> None: |
|||
yield # pragma: nocover |
|||
|
|||
app = FastAPI() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, Depends(dependency_func, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
app.dependency_overrides[dependency_func] = override_dependency_func |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_bad_override_lifespan_scoped_dependencies( |
|||
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
override_dependency_factory = DependencyFactory(dependency_style, should_error=True) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
|||
override_dependency_factory.get_dependency() |
|||
) |
|||
|
|||
with pytest.raises(IntentionallyBadDependency) as exception_info: |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
assert exception_info.value.args == (1,) |
@ -0,0 +1,920 @@ |
|||
import warnings |
|||
from contextlib import asynccontextmanager |
|||
from time import sleep |
|||
from typing import Any, AsyncGenerator, Dict, List, Tuple |
|||
|
|||
import pytest |
|||
from fastapi import ( |
|||
APIRouter, |
|||
BackgroundTasks, |
|||
Body, |
|||
Cookie, |
|||
Depends, |
|||
FastAPI, |
|||
File, |
|||
Form, |
|||
Header, |
|||
Path, |
|||
Query, |
|||
Request, |
|||
WebSocket, |
|||
) |
|||
from fastapi.dependencies.utils import get_endpoint_dependant |
|||
from fastapi.exceptions import ( |
|||
DependencyScopeConflict, |
|||
InvalidDependencyScope, |
|||
UninitializedLifespanDependency, |
|||
) |
|||
from fastapi.params import Security |
|||
from fastapi.security import SecurityScopes |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import Annotated, Literal |
|||
|
|||
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
|||
DependencyFactory, |
|||
DependencyStyle, |
|||
IntentionallyBadDependency, |
|||
create_endpoint_0_annotations, |
|||
create_endpoint_1_annotation, |
|||
create_endpoint_2_annotations, |
|||
create_endpoint_3_annotations, |
|||
use_endpoint, |
|||
use_websocket, |
|||
) |
|||
|
|||
|
|||
def expect_correct_amount_of_dependency_activations( |
|||
*, |
|||
app: FastAPI, |
|||
dependency_factory: DependencyFactory, |
|||
urls_and_responses: List[Tuple[str, Any]], |
|||
expected_activation_times: int, |
|||
is_websocket: bool, |
|||
) -> None: |
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
for url, expected_response in urls_and_responses: |
|||
if is_websocket: |
|||
assert use_websocket(client, url) == expected_response |
|||
else: |
|||
assert use_endpoint(client, url) == expected_response |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == expected_activation_times |
|||
if dependency_factory.dependency_style not in ( |
|||
DependencyStyle.SYNC_FUNCTION, |
|||
DependencyStyle.ASYNC_FUNCTION, |
|||
): |
|||
assert dependency_factory.deactivation_times == expected_activation_times |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize( |
|||
"use_cache", [True, False], ids=["With Cache", "Without Cache"] |
|||
) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoint_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
is_websocket: bool, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, |
|||
Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
), |
|||
], |
|||
expected_value=1, |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_router_dependencies( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
dependency_duplication, |
|||
is_websocket: bool, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=app, path="/test", is_websocket=is_websocket |
|||
) |
|||
else: |
|||
app = FastAPI() |
|||
router = APIRouter(dependencies=[depends] * dependency_duplication) |
|||
|
|||
create_endpoint_0_annotations( |
|||
router=router, path="/test", is_websocket=is_websocket |
|||
) |
|||
|
|||
app.include_router(router) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", None)] * 2, |
|||
expected_activation_times=1 if use_cache else dependency_duplication, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
|||
def test_dependency_cache_in_same_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
use_cache, |
|||
main_dependency_scope: Literal["endpoint", "lifespan"], |
|||
is_websocket: bool, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def dependency( |
|||
sub_dependency1: Annotated[int, depends], |
|||
sub_dependency2: Annotated[int, depends], |
|||
) -> List[int]: |
|||
return [sub_dependency1, sub_dependency2] |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
List[int], |
|||
Depends( |
|||
dependency, |
|||
use_cache=use_cache, |
|||
dependency_scope=main_dependency_scope, |
|||
), |
|||
], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 1]), |
|||
("/test", [1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 2]), |
|||
("/test", [1, 2]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=2, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_same_endpoint( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 1, 1]), |
|||
("/test", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test", [1, 2, 3]), |
|||
("/test", [1, 2, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=3, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_dependency_cache_in_different_endpoints( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
|||
return dependency3 |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test1", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
create_endpoint_3_annotations( |
|||
router=router, |
|||
path="/test2", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[int, depends], |
|||
annotation2=Annotated[int, depends], |
|||
annotation3=Annotated[int, Depends(endpoint_dependency)], |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
if use_cache: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
("/test1", [1, 1, 1]), |
|||
("/test2", [1, 1, 1]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
else: |
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
urls_and_responses=[ |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
("/test1", [1, 2, 3]), |
|||
("/test2", [4, 5, 3]), |
|||
], |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=5, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_no_cached_dependency( |
|||
dependency_style: DependencyStyle, |
|||
routing_style, |
|||
is_websocket, |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=False, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1, |
|||
) |
|||
|
|||
if routing_style == "router": |
|||
app.include_router(router) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
expected_activation_times=1, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize( |
|||
"annotation", |
|||
[ |
|||
Annotated[str, Path()], |
|||
Annotated[str, Body()], |
|||
Annotated[str, Query()], |
|||
Annotated[str, Header()], |
|||
SecurityScopes, |
|||
Annotated[str, Cookie()], |
|||
Annotated[str, Form()], |
|||
Annotated[str, File()], |
|||
BackgroundTasks, |
|||
Request, |
|||
WebSocket, |
|||
], |
|||
) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
|||
annotation, is_websocket |
|||
): |
|||
async def dependency_func(param: annotation) -> None: |
|||
yield # pragma: nocover |
|||
|
|||
app = FastAPI() |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, Depends(dependency_func, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
|||
dependency_style: DependencyStyle, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
async def lifespan_scoped_dependency( |
|||
param: Annotated[ |
|||
int, |
|||
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
|||
], |
|||
) -> AsyncGenerator[int, None]: |
|||
yield param |
|||
|
|||
app = FastAPI() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, Depends(lifespan_scoped_dependency)], |
|||
expected_value=1, |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize( |
|||
["dependency_style", "supports_teardown"], |
|||
[ |
|||
(DependencyStyle.SYNC_FUNCTION, False), |
|||
(DependencyStyle.ASYNC_FUNCTION, False), |
|||
(DependencyStyle.SYNC_GENERATOR, True), |
|||
(DependencyStyle.ASYNC_GENERATOR, True), |
|||
], |
|||
) |
|||
def test_the_same_dependency_can_work_in_different_scopes( |
|||
dependency_style: DependencyStyle, supports_teardown, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
app = FastAPI() |
|||
|
|||
create_endpoint_2_annotations( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation1=Annotated[ |
|||
int, |
|||
Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"), |
|||
], |
|||
annotation2=Annotated[ |
|||
int, |
|||
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
|||
], |
|||
) |
|||
if is_websocket: |
|||
get_response = use_websocket |
|||
else: |
|||
get_response = use_endpoint |
|||
|
|||
assert dependency_factory.activation_times == 0 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
with TestClient(app) as client: |
|||
assert dependency_factory.activation_times == 1 |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert get_response(client, "/test") == [2, 1] |
|||
assert dependency_factory.activation_times == 2 |
|||
if supports_teardown: |
|||
if is_websocket: |
|||
# Websockets teardown might take some time after the test client |
|||
# has disconnected |
|||
sleep(0.1) |
|||
assert dependency_factory.deactivation_times == 1 |
|||
else: |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert get_response(client, "/test") == [3, 1] |
|||
assert dependency_factory.activation_times == 3 |
|||
if supports_teardown: |
|||
if is_websocket: |
|||
# Websockets teardown might take some time after the test client |
|||
# has disconnected |
|||
sleep(0.1) |
|||
assert dependency_factory.deactivation_times == 2 |
|||
else: |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
assert dependency_factory.activation_times == 3 |
|||
if supports_teardown: |
|||
assert dependency_factory.deactivation_times == 3 |
|||
else: |
|||
assert dependency_factory.deactivation_times == 0 |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"] |
|||
) |
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( |
|||
dependency_style: DependencyStyle, |
|||
is_websocket, |
|||
lifespan_style: Literal["lifespan_function", "lifespan_events"], |
|||
): |
|||
lifespan_started = False |
|||
lifespan_ended = False |
|||
if lifespan_style == "lifespan_generator": |
|||
|
|||
@asynccontextmanager |
|||
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: |
|||
nonlocal lifespan_started |
|||
nonlocal lifespan_ended |
|||
lifespan_started = True |
|||
yield |
|||
lifespan_ended = True |
|||
|
|||
app = FastAPI(lifespan=lifespan) |
|||
elif lifespan_style == "events_decorator": |
|||
app = FastAPI() |
|||
with warnings.catch_warnings(record=True): |
|||
warnings.simplefilter("always") |
|||
|
|||
@app.on_event("startup") |
|||
async def startup() -> None: |
|||
nonlocal lifespan_started |
|||
lifespan_started = True |
|||
|
|||
@app.on_event("shutdown") |
|||
async def shutdown() -> None: |
|||
nonlocal lifespan_ended |
|||
lifespan_ended = True |
|||
else: |
|||
assert lifespan_style == "events_constructor" |
|||
|
|||
async def startup() -> None: |
|||
nonlocal lifespan_started |
|||
lifespan_started = True |
|||
|
|||
async def shutdown() -> None: |
|||
nonlocal lifespan_ended |
|||
lifespan_ended = True |
|||
|
|||
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) |
|||
|
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
int, |
|||
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
|||
], |
|||
expected_value=1, |
|||
) |
|||
|
|||
expect_correct_amount_of_dependency_activations( |
|||
app=app, |
|||
dependency_factory=dependency_factory, |
|||
expected_activation_times=1, |
|||
urls_and_responses=[("/test", 1)] * 2, |
|||
is_websocket=is_websocket, |
|||
) |
|||
assert lifespan_started and lifespan_ended |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
|||
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
|||
depends_class, is_websocket |
|||
): |
|||
async def sub_dependency() -> None: |
|||
pass # pragma: nocover |
|||
|
|||
async def dependency_func( |
|||
param: Annotated[None, depends_class(sub_dependency)], |
|||
) -> None: |
|||
pass # pragma: nocover |
|||
|
|||
app = FastAPI() |
|||
|
|||
with pytest.raises(DependencyScopeConflict): |
|||
create_endpoint_1_annotation( |
|||
router=app, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, Depends(dependency_func, dependency_scope="lifespan") |
|||
], |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_dependencies_must_provide_correct_dependency_scope( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
with pytest.raises( |
|||
InvalidDependencyScope, |
|||
match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"', |
|||
): |
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[ |
|||
None, |
|||
Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="incorrect", |
|||
use_cache=use_cache, |
|||
), |
|||
], |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_incorrect_dependency_scope( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
# We intentionally change the dependency scope here to bypass the |
|||
# validation at the function level. |
|||
depends.dependency_scope = "asdad" |
|||
|
|||
with pytest.raises(InvalidDependencyScope): |
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
|||
def test_endpoints_report_incorrect_dependency_scope_at_router_scope( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) |
|||
|
|||
depends = Depends(dependency_factory.get_dependency(), dependency_scope="lifespan") |
|||
|
|||
# We intentionally change the dependency scope here to bypass the |
|||
# validation at the function level. |
|||
depends.dependency_scope = "asdad" |
|||
|
|||
if routing_style == "app": |
|||
app = FastAPI(dependencies=[depends]) |
|||
router = app |
|||
else: |
|||
router = APIRouter(dependencies=[depends]) |
|||
|
|||
with pytest.raises(InvalidDependencyScope): |
|||
create_endpoint_0_annotations( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_dependency( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1, |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
|||
|
|||
try: |
|||
with pytest.raises(UninitializedLifespanDependency): |
|||
if is_websocket: |
|||
with client.websocket_connect("/test"): |
|||
pass # pragma: nocover |
|||
else: |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( |
|||
dependencies |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_endpoints_report_uninitialized_internal_lifespan( |
|||
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
else: |
|||
router = APIRouter() |
|||
|
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1, |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with TestClient(app) as client: |
|||
internal_state = client.app_state["__fastapi__"] |
|||
del client.app_state["__fastapi__"] |
|||
|
|||
try: |
|||
with pytest.raises(UninitializedLifespanDependency): |
|||
if is_websocket: |
|||
with client.websocket_connect("/test"): |
|||
pass # pragma: nocover |
|||
else: |
|||
client.post("/test") |
|||
finally: |
|||
client.app_state["__fastapi__"] = internal_state |
|||
|
|||
|
|||
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
|||
@pytest.mark.parametrize("use_cache", [True, False]) |
|||
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
|||
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
|||
def test_bad_lifespan_scoped_dependencies( |
|||
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket |
|||
): |
|||
dependency_factory = DependencyFactory(dependency_style, should_error=True) |
|||
depends = Depends( |
|||
dependency_factory.get_dependency(), |
|||
dependency_scope="lifespan", |
|||
use_cache=use_cache, |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
if routing_style == "app_endpoint": |
|||
router = app |
|||
|
|||
else: |
|||
router = APIRouter() |
|||
|
|||
create_endpoint_1_annotation( |
|||
router=router, |
|||
path="/test", |
|||
is_websocket=is_websocket, |
|||
annotation=Annotated[int, depends], |
|||
expected_value=1, |
|||
) |
|||
|
|||
if routing_style == "router_endpoint": |
|||
app.include_router(router) |
|||
|
|||
with pytest.raises(IntentionallyBadDependency) as exception_info: |
|||
with TestClient(app): |
|||
pass |
|||
|
|||
assert exception_info.value.args == (1,) |
|||
|
|||
|
|||
def test_endpoint_dependant_backwards_compatibility(): |
|||
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) |
|||
|
|||
def endpoint( |
|||
dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], |
|||
dependency2: Annotated[ |
|||
int, |
|||
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
|||
], |
|||
): |
|||
pass # pragma: nocover |
|||
|
|||
dependant = get_endpoint_dependant( |
|||
path="/test", |
|||
call=endpoint, |
|||
name="endpoint", |
|||
) |
|||
|
|||
assert dependant.dependencies == tuple( |
|||
dependant.lifespan_dependencies + dependant.endpoint_dependencies |
|||
) |
@ -0,0 +1,193 @@ |
|||
from enum import Enum |
|||
from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union |
|||
|
|||
from fastapi import APIRouter, FastAPI, WebSocket |
|||
from fastapi.testclient import TestClient |
|||
from typing_extensions import assert_never |
|||
|
|||
T = TypeVar("T") |
|||
|
|||
|
|||
class DependencyStyle(str, Enum): |
|||
SYNC_FUNCTION = "sync_function" |
|||
ASYNC_FUNCTION = "async_function" |
|||
SYNC_GENERATOR = "sync_generator" |
|||
ASYNC_GENERATOR = "async_generator" |
|||
|
|||
|
|||
class IntentionallyBadDependency(Exception): |
|||
pass |
|||
|
|||
|
|||
class DependencyFactory: |
|||
def __init__( |
|||
self, |
|||
dependency_style: DependencyStyle, |
|||
*, |
|||
should_error: bool = False, |
|||
value_offset: int = 0, |
|||
): |
|||
self.activation_times = 0 |
|||
self.deactivation_times = 0 |
|||
self.dependency_style = dependency_style |
|||
self._should_error = should_error |
|||
self._value_offset = value_offset |
|||
|
|||
def get_dependency(self): |
|||
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
|||
return self._synchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
|||
return self._synchronous_generator_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
|||
return self._asynchronous_function_dependency |
|||
|
|||
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
|||
return self._asynchronous_generator_dependency |
|||
|
|||
assert_never(self.dependency_style) # pragma: nocover |
|||
|
|||
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
yield self.activation_times + self._value_offset |
|||
self.deactivation_times += 1 |
|||
|
|||
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
yield self.activation_times + self._value_offset |
|||
self.deactivation_times += 1 |
|||
|
|||
async def _asynchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
return self.activation_times + self._value_offset |
|||
|
|||
def _synchronous_function_dependency(self) -> T: |
|||
self.activation_times += 1 |
|||
if self._should_error: |
|||
raise IntentionallyBadDependency(self.activation_times) |
|||
|
|||
return self.activation_times + self._value_offset |
|||
|
|||
|
|||
def use_endpoint(client: TestClient, url: str) -> Any: |
|||
response = client.post(url) |
|||
response.raise_for_status() |
|||
return response.json() |
|||
|
|||
|
|||
def use_websocket(client: TestClient, url: str) -> Any: |
|||
with client.websocket_connect(url) as connection: |
|||
return connection.receive_json() |
|||
|
|||
|
|||
def create_endpoint_0_annotations( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
) -> None: |
|||
if is_websocket: |
|||
|
|||
@router.websocket(path) |
|||
async def endpoint(websocket: WebSocket) -> None: |
|||
await websocket.accept() |
|||
await websocket.send_json(None) |
|||
else: |
|||
|
|||
@router.post(path) |
|||
async def endpoint() -> None: |
|||
return None |
|||
|
|||
|
|||
def create_endpoint_1_annotation( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
annotation: Any, |
|||
expected_value: Any = None, |
|||
) -> None: |
|||
if is_websocket: |
|||
|
|||
@router.websocket(path) |
|||
async def endpoint(websocket: WebSocket, value: annotation) -> None: |
|||
if expected_value is not None: |
|||
assert value == expected_value |
|||
|
|||
await websocket.accept() |
|||
await websocket.send_json(value) |
|||
else: |
|||
|
|||
@router.post(path) |
|||
async def endpoint(value: annotation) -> Any: |
|||
if expected_value is not None: |
|||
assert value == expected_value |
|||
|
|||
return value |
|||
|
|||
|
|||
def create_endpoint_2_annotations( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
annotation1: Any, |
|||
annotation2: Any, |
|||
) -> None: |
|||
if is_websocket: |
|||
|
|||
@router.websocket(path) |
|||
async def endpoint( |
|||
websocket: WebSocket, |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
) -> None: |
|||
await websocket.accept() |
|||
await websocket.send_json([value1, value2]) |
|||
else: |
|||
|
|||
@router.post(path) |
|||
async def endpoint( |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
) -> List[Any]: |
|||
return [value1, value2] |
|||
|
|||
|
|||
def create_endpoint_3_annotations( |
|||
*, |
|||
router: Union[APIRouter, FastAPI], |
|||
path: str, |
|||
is_websocket: bool, |
|||
annotation1: Any, |
|||
annotation2: Any, |
|||
annotation3: Any, |
|||
) -> None: |
|||
if is_websocket: |
|||
|
|||
@router.websocket(path) |
|||
async def endpoint( |
|||
websocket: WebSocket, |
|||
value1: annotation1, |
|||
value2: annotation2, |
|||
value3: annotation3, |
|||
) -> None: |
|||
await websocket.accept() |
|||
await websocket.send_json([value1, value2, value3]) |
|||
else: |
|||
|
|||
@router.post(path) |
|||
async def endpoint( |
|||
value1: annotation1, value2: annotation2, value3: annotation3 |
|||
) -> List[Any]: |
|||
return [value1, value2, value3] |
@ -0,0 +1,65 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
|||
mock = MockDatabaseConnection() |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) |
|||
|
|||
return mock |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mock): |
|||
assert database_connection_mock.enter_count == 0 |
|||
assert database_connection_mock.exit_count == 0 |
|||
with TestClient(app) as test_client: |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
response = test_client.get("/users") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 1 |
|||
|
|||
response = test_client.get("/items") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 2 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,70 @@ |
|||
import sys |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
if sys.version_info >= (3, 9): |
|||
from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app |
|||
|
|||
from ...utils import needs_py39 |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
|||
mock = MockDatabaseConnection() |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) |
|||
|
|||
return mock |
|||
|
|||
|
|||
@needs_py39 |
|||
def test_dependency_usage(database_connection_mock): |
|||
assert database_connection_mock.enter_count == 0 |
|||
assert database_connection_mock.exit_count == 0 |
|||
with TestClient(app) as test_client: |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
response = test_client.get("/users") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 1 |
|||
|
|||
response = test_client.get("/items") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 2 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,130 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
|
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 3 |
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
assert connection.get_records_count == 0 |
|||
assert connection.get_record_count == 0 |
|||
|
|||
response = test_client.get("/users") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
users_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1: |
|||
users_connection = connection |
|||
break |
|||
|
|||
assert users_connection is not None, ( |
|||
"No connection was found for users endpoint" |
|||
) |
|||
|
|||
response = test_client.get("/groups") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
groups_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1 and connection is not users_connection: |
|||
groups_connection = connection |
|||
break |
|||
|
|||
assert groups_connection is not None, ( |
|||
"No connection was found for groups endpoint" |
|||
) |
|||
assert groups_connection.get_records_count == 1 |
|||
|
|||
items_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 0: |
|||
items_connection = connection |
|||
break |
|||
|
|||
assert items_connection is not None, ( |
|||
"No connection was found for items endpoint" |
|||
) |
|||
|
|||
response = test_client.get("/items") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 0 |
|||
|
|||
response = test_client.get("/items/asd") |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "items", |
|||
"record_id": "asd", |
|||
} |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 1 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 1 |
@ -0,0 +1,135 @@ |
|||
import sys |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
if sys.version_info >= (3, 9): |
|||
from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app |
|||
|
|||
from ...utils import needs_py39 |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
|
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
|||
return connections |
|||
|
|||
|
|||
@needs_py39 |
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 3 |
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
assert connection.get_records_count == 0 |
|||
assert connection.get_record_count == 0 |
|||
|
|||
response = test_client.get("/users") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
users_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1: |
|||
users_connection = connection |
|||
break |
|||
|
|||
assert users_connection is not None, ( |
|||
"No connection was found for users endpoint" |
|||
) |
|||
|
|||
response = test_client.get("/groups") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
groups_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1 and connection is not users_connection: |
|||
groups_connection = connection |
|||
break |
|||
|
|||
assert groups_connection is not None, ( |
|||
"No connection was found for groups endpoint" |
|||
) |
|||
assert groups_connection.get_records_count == 1 |
|||
|
|||
items_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 0: |
|||
items_connection = connection |
|||
break |
|||
|
|||
assert items_connection is not None, ( |
|||
"No connection was found for items endpoint" |
|||
) |
|||
|
|||
response = test_client.get("/items") |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 0 |
|||
|
|||
response = test_client.get("/items/asd") |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "items", |
|||
"record_id": "asd", |
|||
} |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 1 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 1 |
@ -0,0 +1,78 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self, url: str): |
|||
self.url = url |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
|
|||
def _get_new_connection_mock(cls, url): |
|||
mock = MockDatabaseConnection(url) |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.url == "sqlite:///database.db" |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get("/users/user") |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,83 @@ |
|||
import sys |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
if sys.version_info >= (3, 9): |
|||
from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app |
|||
|
|||
from ...utils import needs_py39 |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self, url: str): |
|||
self.url = url |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
|
|||
def _get_new_connection_mock(cls, url): |
|||
mock = MockDatabaseConnection(url) |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
|||
return connections |
|||
|
|||
|
|||
@needs_py39 |
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.url == "sqlite:///database.db" |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get("/users/user") |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,76 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
|
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get("/users/user") |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,81 @@ |
|||
import sys |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
if sys.version_info >= (3, 9): |
|||
from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app |
|||
|
|||
from ...utils import needs_py39 |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
|
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
|||
return connections |
|||
|
|||
|
|||
@needs_py39 |
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get("/users/user") |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
Loading…
Reference in new issue