committed by
GitHub
33 changed files with 3513 additions and 92 deletions
@ -0,0 +1,111 @@ |
|||||
|
# Lifespan Scoped Dependencies |
||||
|
|
||||
|
## Intro |
||||
|
|
||||
|
So far we've used dependencies which are "endpoint scoped". Meaning, they are |
||||
|
called again and again for every incoming request to the endpoint. However, |
||||
|
this is not always ideal: |
||||
|
|
||||
|
* Sometimes dependencies have a large setup/teardown time. Running it for every request will result in bad performance. |
||||
|
* Sometimes dependencies need to have their values shared throughout the lifespan |
||||
|
of the application between multiple requests. |
||||
|
|
||||
|
|
||||
|
An example of this would be a connection to a database. Databases are typically |
||||
|
less efficient when working with lots of connections and would prefer that |
||||
|
clients would create a single connection for their operations. |
||||
|
|
||||
|
For such cases can be solved by using "lifespan scoped dependencies". |
||||
|
|
||||
|
|
||||
|
## What is a lifespan scoped dependency? |
||||
|
Lifespan scoped dependencies work similarly to the (endpoint scoped) |
||||
|
dependencies we've worked with so far. However, unlike endpoint scoped |
||||
|
dependencies, lifespan scoped dependencies are called once and only |
||||
|
once in the application's lifespan: |
||||
|
|
||||
|
* During the application startup process, all lifespan scoped dependencies will |
||||
|
be called. |
||||
|
* Their returned value will be shared across all requests to the application. |
||||
|
* During the application's shutdown process, all lifespan scoped dependencies |
||||
|
will be gracefully teared down. |
||||
|
|
||||
|
|
||||
|
## Create a lifespan scoped dependency |
||||
|
|
||||
|
You may declare a dependency as a lifespan scoped dependency by passing |
||||
|
`dependency_scope="lifespan"` to the `Depends` function: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013a_an_py39.py *} |
||||
|
|
||||
|
/// tip |
||||
|
|
||||
|
In the example above we saved the annotation to a separate variable, and then |
||||
|
reused it in our endpoints. This is not a requirement, we could also declare |
||||
|
the exact same annotation in both endpoints. However, it is recommended that you |
||||
|
do save the annotation to a variable so you won't accidentally forget to pass |
||||
|
`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint |
||||
|
to create a new database connection for every request). |
||||
|
|
||||
|
/// |
||||
|
|
||||
|
In this example, the `get_database_connection` dependency will be executed once, |
||||
|
during the application's startup. **FastAPI** will internally save the resulting |
||||
|
connection object, and whenever the `read_users` and `read_items` endpoints are |
||||
|
called, they will be using the previously saved connection. Once the application |
||||
|
shuts down, **FastAPI** will make sure to gracefully close the connection object. |
||||
|
|
||||
|
## The `use_cache` argument |
||||
|
|
||||
|
The `use_cache` argument works similarly to the way it worked with endpoint |
||||
|
scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it |
||||
|
will cache dependencies it already encountered before. However, you can disable |
||||
|
this behavior by passing `use_cache=False` to `Depends`: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013b_an_py39.py *} |
||||
|
|
||||
|
In this example, the `read_users` and `read_groups` endpoints are using |
||||
|
`use_cache=False` whereas the `read_items` and `read_item` are using |
||||
|
`use_cache=True`. |
||||
|
That means that we'll have a total of 3 connections created |
||||
|
for the duration of the application's lifespan: |
||||
|
|
||||
|
* One connection will be shared across all requests for the `read_items` and `read_item` endpoints. |
||||
|
* A second connection will be shared across all requests for the `read_users` endpoint. |
||||
|
* A third and final connection will be shared across all requests for the `read_groups` endpoint. |
||||
|
|
||||
|
|
||||
|
## Lifespan Scoped Sub-Dependencies |
||||
|
Just like with endpoint scoped dependencies, lifespan scoped dependencies may |
||||
|
use other lifespan scoped sub-dependencies themselves: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013c_an_py39.py *} |
||||
|
|
||||
|
Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013d_an_py39.py *} |
||||
|
|
||||
|
/// note |
||||
|
|
||||
|
You can pass `dependency_scope="endpoint"` if you wish to explicitly specify |
||||
|
that a dependency is endpoint scoped. It will work the same as not specifying |
||||
|
a dependency scope at all. |
||||
|
|
||||
|
/// |
||||
|
|
||||
|
As you can see, regardless of the scope, dependencies can use lifespan scoped |
||||
|
sub-dependencies. |
||||
|
|
||||
|
## Dependency Scope Conflicts |
||||
|
By definition, lifespan scoped dependencies are being setup in the application's |
||||
|
startup process, before any request is ever being made to any endpoint. |
||||
|
Therefore, it is not possible for a lifespan scoped dependency to use any |
||||
|
parameters that require the scope of an endpoint. |
||||
|
|
||||
|
That includes but not limited to: |
||||
|
|
||||
|
* Parts of the request (like `Body`, `Query` and `Path`) |
||||
|
* The request/response objects themselves (like `Request`, `Response` and `WebSocket`) |
||||
|
* Endpoint scoped sub-dependencies. |
||||
|
|
||||
|
Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. |
@ -0,0 +1,44 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
from fastapi import Depends, FastAPI |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
): |
||||
|
return await database_connection.get_records("items") |
@ -0,0 +1,42 @@ |
|||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> list[dict]: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[ |
||||
|
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
] |
||||
|
|
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users(database_connection: GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items(database_connection: GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("items") |
@ -0,0 +1,65 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
DedicatedDatabaseConnection = Depends( |
||||
|
get_database_connection, dependency_scope="lifespan", use_cache=False |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@app.get("/groups/") |
||||
|
async def read_groups( |
||||
|
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, |
||||
|
): |
||||
|
return await database_connection.get_records("groups") |
||||
|
|
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users( |
||||
|
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection, |
||||
|
): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
): |
||||
|
return await database_connection.get_records("items") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/{item_id}") |
||||
|
async def read_item( |
||||
|
item_id: str = Path(), |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
): |
||||
|
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,61 @@ |
|||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> list[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[ |
||||
|
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
] |
||||
|
DedicatedDatabaseConnection = Annotated[ |
||||
|
MyDatabaseConnection, |
||||
|
Depends(get_database_connection, dependency_scope="lifespan", use_cache=False), |
||||
|
] |
||||
|
|
||||
|
|
||||
|
@app.get("/groups/") |
||||
|
async def read_groups(database_connection: DedicatedDatabaseConnection): |
||||
|
return await database_connection.get_records("groups") |
||||
|
|
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users(database_connection: DedicatedDatabaseConnection): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items(database_connection: GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("items") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/{item_id}") |
||||
|
async def read_item( |
||||
|
database_connection: GlobalDatabaseConnection, item_id: Annotated[str, Path()] |
||||
|
): |
||||
|
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,50 @@ |
|||||
|
from dataclasses import dataclass |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
connection_string: str |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_configuration() -> dict: |
||||
|
return { |
||||
|
"database_url": "sqlite:///database.db", |
||||
|
} |
||||
|
|
||||
|
|
||||
|
GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(configuration: dict = GlobalConfiguration): |
||||
|
async with MyDatabaseConnection(configuration["database_url"]) as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
user_id: str = Path(), |
||||
|
): |
||||
|
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,55 @@ |
|||||
|
from dataclasses import dataclass |
||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
connection_string: str |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_configuration() -> dict: |
||||
|
return { |
||||
|
"database_url": "sqlite:///database.db", |
||||
|
} |
||||
|
|
||||
|
|
||||
|
GlobalConfiguration = Annotated[ |
||||
|
dict, Depends(get_configuration, dependency_scope="lifespan") |
||||
|
] |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(configuration: GlobalConfiguration): |
||||
|
async with MyDatabaseConnection(configuration["database_url"]) as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[ |
||||
|
get_database_connection, |
||||
|
Depends(get_database_connection, dependency_scope="lifespan"), |
||||
|
] |
||||
|
|
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user( |
||||
|
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] |
||||
|
): |
||||
|
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,40 @@ |
|||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
async def get_user_record( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
user_id: str = Path(), |
||||
|
) -> dict: |
||||
|
return await database_connection.get_record("users", user_id) |
||||
|
|
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user(user_record: dict = Depends(get_user_record)): |
||||
|
return user_record |
@ -0,0 +1,43 @@ |
|||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[ |
||||
|
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
] |
||||
|
|
||||
|
|
||||
|
async def get_user_record( |
||||
|
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()] |
||||
|
) -> dict: |
||||
|
return await database_connection.get_record("users", user_id) |
||||
|
|
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user(user_record: Annotated[dict, Depends(get_user_record)]): |
||||
|
return user_record |
@ -0,0 +1,44 @@ |
|||||
|
from __future__ import annotations |
||||
|
|
||||
|
from contextlib import AsyncExitStack |
||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List |
||||
|
|
||||
|
from fastapi.dependencies.models import LifespanDependant, LifespanDependantCacheKey |
||||
|
from fastapi.dependencies.utils import solve_lifespan_dependant |
||||
|
from fastapi.routing import APIRoute, APIWebSocketRoute |
||||
|
|
||||
|
if TYPE_CHECKING: # pragma: nocover |
||||
|
from fastapi import FastAPI |
||||
|
|
||||
|
|
||||
|
def _get_lifespan_dependants(app: FastAPI) -> List[LifespanDependant]: |
||||
|
lifespan_dependants_cache: Dict[LifespanDependantCacheKey, LifespanDependant] = {} |
||||
|
for route in app.router.routes: |
||||
|
if not isinstance(route, (APIWebSocketRoute, APIRoute)): |
||||
|
continue |
||||
|
|
||||
|
for sub_dependant in route.lifespan_dependencies: |
||||
|
if sub_dependant.cache_key in lifespan_dependants_cache: |
||||
|
continue |
||||
|
|
||||
|
lifespan_dependants_cache[sub_dependant.cache_key] = sub_dependant |
||||
|
|
||||
|
return list(lifespan_dependants_cache.values()) |
||||
|
|
||||
|
|
||||
|
async def resolve_lifespan_dependants( |
||||
|
*, app: FastAPI, async_exit_stack: AsyncExitStack |
||||
|
) -> Dict[LifespanDependantCacheKey, Callable[..., Any]]: |
||||
|
lifespan_dependants = _get_lifespan_dependants(app) |
||||
|
dependency_cache: Dict[LifespanDependantCacheKey, Callable[..., Any]] = {} |
||||
|
for lifespan_dependant in lifespan_dependants: |
||||
|
solved_dependency = await solve_lifespan_dependant( |
||||
|
dependant=lifespan_dependant, |
||||
|
dependency_overrides_provider=app, |
||||
|
dependency_cache=dependency_cache, |
||||
|
async_exit_stack=async_exit_stack, |
||||
|
) |
||||
|
|
||||
|
dependency_cache.update(solved_dependency.dependency_cache) |
||||
|
|
||||
|
return dependency_cache |
@ -0,0 +1,624 @@ |
|||||
|
from typing import Any, AsyncGenerator, List, Tuple |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import ( |
||||
|
APIRouter, |
||||
|
BackgroundTasks, |
||||
|
Body, |
||||
|
Cookie, |
||||
|
Depends, |
||||
|
FastAPI, |
||||
|
File, |
||||
|
Form, |
||||
|
Header, |
||||
|
Path, |
||||
|
Query, |
||||
|
Request, |
||||
|
WebSocket, |
||||
|
) |
||||
|
from fastapi.exceptions import DependencyScopeConflict |
||||
|
from fastapi.params import Security |
||||
|
from fastapi.security import SecurityScopes |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated, Literal |
||||
|
|
||||
|
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
||||
|
DependencyFactory, |
||||
|
DependencyStyle, |
||||
|
IntentionallyBadDependency, |
||||
|
create_endpoint_0_annotations, |
||||
|
create_endpoint_1_annotation, |
||||
|
create_endpoint_3_annotations, |
||||
|
use_endpoint, |
||||
|
use_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def expect_correct_amount_of_dependency_activations( |
||||
|
*, |
||||
|
app: FastAPI, |
||||
|
dependency_factory: DependencyFactory, |
||||
|
override_dependency_factory: DependencyFactory, |
||||
|
urls_and_responses: List[Tuple[str, Any]], |
||||
|
expected_activation_times: int, |
||||
|
is_websocket: bool, |
||||
|
) -> None: |
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == 0 |
||||
|
assert override_dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
with TestClient(app) as client: |
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == expected_activation_times |
||||
|
assert override_dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
for url, expected_response in urls_and_responses: |
||||
|
if is_websocket: |
||||
|
response = use_websocket(client, url) |
||||
|
else: |
||||
|
response = use_endpoint(client, url) |
||||
|
|
||||
|
assert response == expected_response |
||||
|
|
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert ( |
||||
|
override_dependency_factory.activation_times |
||||
|
== expected_activation_times |
||||
|
) |
||||
|
assert override_dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert override_dependency_factory.activation_times == expected_activation_times |
||||
|
if dependency_factory.dependency_style not in ( |
||||
|
DependencyStyle.SYNC_FUNCTION, |
||||
|
DependencyStyle.ASYNC_FUNCTION, |
||||
|
): |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
assert ( |
||||
|
override_dependency_factory.deactivation_times == expected_activation_times |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoint_dependencies( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, |
||||
|
Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
), |
||||
|
], |
||||
|
expected_value=11, |
||||
|
) |
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
urls_and_responses=[("/test", 11)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_router_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
dependency_duplication, |
||||
|
is_websocket, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
app = FastAPI(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=app, path="/test", is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
app = FastAPI() |
||||
|
router = APIRouter(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=router, path="/test", is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
urls_and_responses=[("/test", None)] * 2, |
||||
|
expected_activation_times=1 if use_cache else dependency_duplication, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
||||
|
def test_dependency_cache_in_same_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
main_dependency_scope: Literal["endpoint", "lifespan"], |
||||
|
is_websocket, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def dependency( |
||||
|
sub_dependency1: Annotated[int, depends], |
||||
|
sub_dependency2: Annotated[int, depends], |
||||
|
) -> List[int]: |
||||
|
return [sub_dependency1, sub_dependency2] |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
List[int], |
||||
|
Depends( |
||||
|
dependency, |
||||
|
use_cache=use_cache, |
||||
|
dependency_scope=main_dependency_scope, |
||||
|
), |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [11, 11]), |
||||
|
("/test", [11, 11]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [11, 12]), |
||||
|
("/test", [11, 12]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=2, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_same_endpoint( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test1", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 11, 11]), |
||||
|
("/test1", [11, 11, 11]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 12, 13]), |
||||
|
("/test1", [11, 12, 13]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=3, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_different_endpoints( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test1", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test2", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 11, 11]), |
||||
|
("/test2", [11, 11, 11]), |
||||
|
("/test1", [11, 11, 11]), |
||||
|
("/test2", [11, 11, 11]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [11, 12, 13]), |
||||
|
("/test2", [14, 15, 13]), |
||||
|
("/test1", [11, 12, 13]), |
||||
|
("/test2", [14, 15, 13]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=5, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_no_cached_dependency( |
||||
|
dependency_style: DependencyStyle, routing_style, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=False, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
urls_and_responses=[("/test", 11)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize( |
||||
|
"annotation", |
||||
|
[ |
||||
|
Annotated[str, Path()], |
||||
|
Annotated[str, Body()], |
||||
|
Annotated[str, Query()], |
||||
|
Annotated[str, Header()], |
||||
|
SecurityScopes, |
||||
|
Annotated[str, Cookie()], |
||||
|
Annotated[str, Form()], |
||||
|
Annotated[str, File()], |
||||
|
BackgroundTasks, |
||||
|
Request, |
||||
|
WebSocket, |
||||
|
], |
||||
|
) |
||||
|
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
||||
|
annotation, is_websocket |
||||
|
): |
||||
|
async def dependency_func() -> None: |
||||
|
yield # pragma: nocover |
||||
|
|
||||
|
async def override_dependency_func(param: annotation) -> None: |
||||
|
yield # pragma: nocover |
||||
|
|
||||
|
app = FastAPI() |
||||
|
app.dependency_overrides[dependency_func] = override_dependency_func |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, Depends(dependency_func, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
def test_non_override_lifespan_scoped_dependency_can_use_overridden_lifespan_scoped_dependencies( |
||||
|
dependency_style: DependencyStyle, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, value_offset=10) |
||||
|
|
||||
|
async def lifespan_scoped_dependency( |
||||
|
param: Annotated[ |
||||
|
int, |
||||
|
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
||||
|
], |
||||
|
) -> AsyncGenerator[int, None]: |
||||
|
yield param |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
int, Depends(lifespan_scoped_dependency, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
override_dependency_factory=override_dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
urls_and_responses=[("/test", 11)] * 2, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
||||
|
def test_override_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
||||
|
depends_class, is_websocket |
||||
|
): |
||||
|
async def sub_dependency() -> None: |
||||
|
pass # pragma: nocover |
||||
|
|
||||
|
async def dependency_func() -> None: |
||||
|
yield # pragma: nocover |
||||
|
|
||||
|
async def override_dependency_func( |
||||
|
param: Annotated[None, depends_class(sub_dependency)], |
||||
|
) -> None: |
||||
|
yield # pragma: nocover |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, Depends(dependency_func, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
app.dependency_overrides[dependency_func] = override_dependency_func |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_bad_override_lifespan_scoped_dependencies( |
||||
|
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
override_dependency_factory = DependencyFactory(dependency_style, should_error=True) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
app.dependency_overrides[dependency_factory.get_dependency()] = ( |
||||
|
override_dependency_factory.get_dependency() |
||||
|
) |
||||
|
|
||||
|
with pytest.raises(IntentionallyBadDependency) as exception_info: |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
assert exception_info.value.args == (1,) |
@ -0,0 +1,920 @@ |
|||||
|
import warnings |
||||
|
from contextlib import asynccontextmanager |
||||
|
from time import sleep |
||||
|
from typing import Any, AsyncGenerator, Dict, List, Tuple |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import ( |
||||
|
APIRouter, |
||||
|
BackgroundTasks, |
||||
|
Body, |
||||
|
Cookie, |
||||
|
Depends, |
||||
|
FastAPI, |
||||
|
File, |
||||
|
Form, |
||||
|
Header, |
||||
|
Path, |
||||
|
Query, |
||||
|
Request, |
||||
|
WebSocket, |
||||
|
) |
||||
|
from fastapi.dependencies.utils import get_endpoint_dependant |
||||
|
from fastapi.exceptions import ( |
||||
|
DependencyScopeConflict, |
||||
|
InvalidDependencyScope, |
||||
|
UninitializedLifespanDependency, |
||||
|
) |
||||
|
from fastapi.params import Security |
||||
|
from fastapi.security import SecurityScopes |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated, Literal |
||||
|
|
||||
|
from tests.test_lifespan_scoped_dependencies.testing_utilities import ( |
||||
|
DependencyFactory, |
||||
|
DependencyStyle, |
||||
|
IntentionallyBadDependency, |
||||
|
create_endpoint_0_annotations, |
||||
|
create_endpoint_1_annotation, |
||||
|
create_endpoint_2_annotations, |
||||
|
create_endpoint_3_annotations, |
||||
|
use_endpoint, |
||||
|
use_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def expect_correct_amount_of_dependency_activations( |
||||
|
*, |
||||
|
app: FastAPI, |
||||
|
dependency_factory: DependencyFactory, |
||||
|
urls_and_responses: List[Tuple[str, Any]], |
||||
|
expected_activation_times: int, |
||||
|
is_websocket: bool, |
||||
|
) -> None: |
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
with TestClient(app) as client: |
||||
|
assert dependency_factory.activation_times == expected_activation_times |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
for url, expected_response in urls_and_responses: |
||||
|
if is_websocket: |
||||
|
assert use_websocket(client, url) == expected_response |
||||
|
else: |
||||
|
assert use_endpoint(client, url) == expected_response |
||||
|
|
||||
|
assert dependency_factory.activation_times == expected_activation_times |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert dependency_factory.activation_times == expected_activation_times |
||||
|
if dependency_factory.dependency_style not in ( |
||||
|
DependencyStyle.SYNC_FUNCTION, |
||||
|
DependencyStyle.ASYNC_FUNCTION, |
||||
|
): |
||||
|
assert dependency_factory.deactivation_times == expected_activation_times |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize( |
||||
|
"use_cache", [True, False], ids=["With Cache", "Without Cache"] |
||||
|
) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoint_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
is_websocket: bool, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, |
||||
|
Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
), |
||||
|
], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_duplication", [1, 2]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_router_dependencies( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
dependency_duplication, |
||||
|
is_websocket: bool, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
app = FastAPI(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=app, path="/test", is_websocket=is_websocket |
||||
|
) |
||||
|
else: |
||||
|
app = FastAPI() |
||||
|
router = APIRouter(dependencies=[depends] * dependency_duplication) |
||||
|
|
||||
|
create_endpoint_0_annotations( |
||||
|
router=router, path="/test", is_websocket=is_websocket |
||||
|
) |
||||
|
|
||||
|
app.include_router(router) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
urls_and_responses=[("/test", None)] * 2, |
||||
|
expected_activation_times=1 if use_cache else dependency_duplication, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
@pytest.mark.parametrize("main_dependency_scope", ["endpoint", "lifespan"]) |
||||
|
def test_dependency_cache_in_same_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
use_cache, |
||||
|
main_dependency_scope: Literal["endpoint", "lifespan"], |
||||
|
is_websocket: bool, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def dependency( |
||||
|
sub_dependency1: Annotated[int, depends], |
||||
|
sub_dependency2: Annotated[int, depends], |
||||
|
) -> List[int]: |
||||
|
return [sub_dependency1, sub_dependency2] |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
List[int], |
||||
|
Depends( |
||||
|
dependency, |
||||
|
use_cache=use_cache, |
||||
|
dependency_scope=main_dependency_scope, |
||||
|
), |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 1]), |
||||
|
("/test", [1, 1]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 2]), |
||||
|
("/test", [1, 2]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=2, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_same_endpoint( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 1, 1]), |
||||
|
("/test", [1, 1, 1]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test", [1, 2, 3]), |
||||
|
("/test", [1, 2, 3]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=3, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_dependency_cache_in_different_endpoints( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
async def endpoint_dependency(dependency3: Annotated[int, depends]) -> int: |
||||
|
return dependency3 |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test1", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
create_endpoint_3_annotations( |
||||
|
router=router, |
||||
|
path="/test2", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[int, depends], |
||||
|
annotation2=Annotated[int, depends], |
||||
|
annotation3=Annotated[int, Depends(endpoint_dependency)], |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
if use_cache: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [1, 1, 1]), |
||||
|
("/test2", [1, 1, 1]), |
||||
|
("/test1", [1, 1, 1]), |
||||
|
("/test2", [1, 1, 1]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
else: |
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
urls_and_responses=[ |
||||
|
("/test1", [1, 2, 3]), |
||||
|
("/test2", [4, 5, 3]), |
||||
|
("/test1", [1, 2, 3]), |
||||
|
("/test2", [4, 5, 3]), |
||||
|
], |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=5, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_no_cached_dependency( |
||||
|
dependency_style: DependencyStyle, |
||||
|
routing_style, |
||||
|
is_websocket, |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=False, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
expected_activation_times=1, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize( |
||||
|
"annotation", |
||||
|
[ |
||||
|
Annotated[str, Path()], |
||||
|
Annotated[str, Body()], |
||||
|
Annotated[str, Query()], |
||||
|
Annotated[str, Header()], |
||||
|
SecurityScopes, |
||||
|
Annotated[str, Cookie()], |
||||
|
Annotated[str, Form()], |
||||
|
Annotated[str, File()], |
||||
|
BackgroundTasks, |
||||
|
Request, |
||||
|
WebSocket, |
||||
|
], |
||||
|
) |
||||
|
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_parameters( |
||||
|
annotation, is_websocket |
||||
|
): |
||||
|
async def dependency_func(param: annotation) -> None: |
||||
|
yield # pragma: nocover |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, Depends(dependency_func, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
def test_lifespan_scoped_dependency_can_use_other_lifespan_scoped_dependencies( |
||||
|
dependency_style: DependencyStyle, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
async def lifespan_scoped_dependency( |
||||
|
param: Annotated[ |
||||
|
int, |
||||
|
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
||||
|
], |
||||
|
) -> AsyncGenerator[int, None]: |
||||
|
yield param |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, Depends(lifespan_scoped_dependency)], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize( |
||||
|
["dependency_style", "supports_teardown"], |
||||
|
[ |
||||
|
(DependencyStyle.SYNC_FUNCTION, False), |
||||
|
(DependencyStyle.ASYNC_FUNCTION, False), |
||||
|
(DependencyStyle.SYNC_GENERATOR, True), |
||||
|
(DependencyStyle.ASYNC_GENERATOR, True), |
||||
|
], |
||||
|
) |
||||
|
def test_the_same_dependency_can_work_in_different_scopes( |
||||
|
dependency_style: DependencyStyle, supports_teardown, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
app = FastAPI() |
||||
|
|
||||
|
create_endpoint_2_annotations( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation1=Annotated[ |
||||
|
int, |
||||
|
Depends(dependency_factory.get_dependency(), dependency_scope="endpoint"), |
||||
|
], |
||||
|
annotation2=Annotated[ |
||||
|
int, |
||||
|
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
||||
|
], |
||||
|
) |
||||
|
if is_websocket: |
||||
|
get_response = use_websocket |
||||
|
else: |
||||
|
get_response = use_endpoint |
||||
|
|
||||
|
assert dependency_factory.activation_times == 0 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
with TestClient(app) as client: |
||||
|
assert dependency_factory.activation_times == 1 |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert get_response(client, "/test") == [2, 1] |
||||
|
assert dependency_factory.activation_times == 2 |
||||
|
if supports_teardown: |
||||
|
if is_websocket: |
||||
|
# Websockets teardown might take some time after the test client |
||||
|
# has disconnected |
||||
|
sleep(0.1) |
||||
|
assert dependency_factory.deactivation_times == 1 |
||||
|
else: |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert get_response(client, "/test") == [3, 1] |
||||
|
assert dependency_factory.activation_times == 3 |
||||
|
if supports_teardown: |
||||
|
if is_websocket: |
||||
|
# Websockets teardown might take some time after the test client |
||||
|
# has disconnected |
||||
|
sleep(0.1) |
||||
|
assert dependency_factory.deactivation_times == 2 |
||||
|
else: |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
assert dependency_factory.activation_times == 3 |
||||
|
if supports_teardown: |
||||
|
assert dependency_factory.deactivation_times == 3 |
||||
|
else: |
||||
|
assert dependency_factory.deactivation_times == 0 |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize( |
||||
|
"lifespan_style", ["lifespan_generator", "events_decorator", "events_constructor"] |
||||
|
) |
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
def test_lifespan_scoped_dependency_can_be_used_alongside_custom_lifespans( |
||||
|
dependency_style: DependencyStyle, |
||||
|
is_websocket, |
||||
|
lifespan_style: Literal["lifespan_function", "lifespan_events"], |
||||
|
): |
||||
|
lifespan_started = False |
||||
|
lifespan_ended = False |
||||
|
if lifespan_style == "lifespan_generator": |
||||
|
|
||||
|
@asynccontextmanager |
||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, int], None]: |
||||
|
nonlocal lifespan_started |
||||
|
nonlocal lifespan_ended |
||||
|
lifespan_started = True |
||||
|
yield |
||||
|
lifespan_ended = True |
||||
|
|
||||
|
app = FastAPI(lifespan=lifespan) |
||||
|
elif lifespan_style == "events_decorator": |
||||
|
app = FastAPI() |
||||
|
with warnings.catch_warnings(record=True): |
||||
|
warnings.simplefilter("always") |
||||
|
|
||||
|
@app.on_event("startup") |
||||
|
async def startup() -> None: |
||||
|
nonlocal lifespan_started |
||||
|
lifespan_started = True |
||||
|
|
||||
|
@app.on_event("shutdown") |
||||
|
async def shutdown() -> None: |
||||
|
nonlocal lifespan_ended |
||||
|
lifespan_ended = True |
||||
|
else: |
||||
|
assert lifespan_style == "events_constructor" |
||||
|
|
||||
|
async def startup() -> None: |
||||
|
nonlocal lifespan_started |
||||
|
lifespan_started = True |
||||
|
|
||||
|
async def shutdown() -> None: |
||||
|
nonlocal lifespan_ended |
||||
|
lifespan_ended = True |
||||
|
|
||||
|
app = FastAPI(on_startup=[startup], on_shutdown=[shutdown]) |
||||
|
|
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
int, |
||||
|
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
||||
|
], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
expect_correct_amount_of_dependency_activations( |
||||
|
app=app, |
||||
|
dependency_factory=dependency_factory, |
||||
|
expected_activation_times=1, |
||||
|
urls_and_responses=[("/test", 1)] * 2, |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
assert lifespan_started and lifespan_ended |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("depends_class", [Depends, Security]) |
||||
|
def test_lifespan_scoped_dependency_cannot_use_endpoint_scoped_dependencies( |
||||
|
depends_class, is_websocket |
||||
|
): |
||||
|
async def sub_dependency() -> None: |
||||
|
pass # pragma: nocover |
||||
|
|
||||
|
async def dependency_func( |
||||
|
param: Annotated[None, depends_class(sub_dependency)], |
||||
|
) -> None: |
||||
|
pass # pragma: nocover |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
with pytest.raises(DependencyScopeConflict): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=app, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, Depends(dependency_func, dependency_scope="lifespan") |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_dependencies_must_provide_correct_dependency_scope( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
with pytest.raises( |
||||
|
InvalidDependencyScope, |
||||
|
match=r'Dependency "value" of .* has an invalid scope: ' r'"incorrect"', |
||||
|
): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[ |
||||
|
None, |
||||
|
Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="incorrect", |
||||
|
use_cache=use_cache, |
||||
|
), |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoints_report_incorrect_dependency_scope( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
# We intentionally change the dependency scope here to bypass the |
||||
|
# validation at the function level. |
||||
|
depends.dependency_scope = "asdad" |
||||
|
|
||||
|
with pytest.raises(InvalidDependencyScope): |
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app", "router"]) |
||||
|
def test_endpoints_report_incorrect_dependency_scope_at_router_scope( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) |
||||
|
|
||||
|
depends = Depends(dependency_factory.get_dependency(), dependency_scope="lifespan") |
||||
|
|
||||
|
# We intentionally change the dependency scope here to bypass the |
||||
|
# validation at the function level. |
||||
|
depends.dependency_scope = "asdad" |
||||
|
|
||||
|
if routing_style == "app": |
||||
|
app = FastAPI(dependencies=[depends]) |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter(dependencies=[depends]) |
||||
|
|
||||
|
with pytest.raises(InvalidDependencyScope): |
||||
|
create_endpoint_0_annotations( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoints_report_uninitialized_dependency( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
with TestClient(app) as client: |
||||
|
dependencies = client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] |
||||
|
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = {} |
||||
|
|
||||
|
try: |
||||
|
with pytest.raises(UninitializedLifespanDependency): |
||||
|
if is_websocket: |
||||
|
with client.websocket_connect("/test"): |
||||
|
pass # pragma: nocover |
||||
|
else: |
||||
|
client.post("/test") |
||||
|
finally: |
||||
|
client.app_state["__fastapi__"]["lifespan_scoped_dependencies"] = ( |
||||
|
dependencies |
||||
|
) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_endpoints_report_uninitialized_internal_lifespan( |
||||
|
dependency_style: DependencyStyle, routing_style, use_cache, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
with TestClient(app) as client: |
||||
|
internal_state = client.app_state["__fastapi__"] |
||||
|
del client.app_state["__fastapi__"] |
||||
|
|
||||
|
try: |
||||
|
with pytest.raises(UninitializedLifespanDependency): |
||||
|
if is_websocket: |
||||
|
with client.websocket_connect("/test"): |
||||
|
pass # pragma: nocover |
||||
|
else: |
||||
|
client.post("/test") |
||||
|
finally: |
||||
|
client.app_state["__fastapi__"] = internal_state |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("is_websocket", [True, False], ids=["Websocket", "Endpoint"]) |
||||
|
@pytest.mark.parametrize("use_cache", [True, False]) |
||||
|
@pytest.mark.parametrize("dependency_style", list(DependencyStyle)) |
||||
|
@pytest.mark.parametrize("routing_style", ["app_endpoint", "router_endpoint"]) |
||||
|
def test_bad_lifespan_scoped_dependencies( |
||||
|
use_cache, dependency_style: DependencyStyle, routing_style, is_websocket |
||||
|
): |
||||
|
dependency_factory = DependencyFactory(dependency_style, should_error=True) |
||||
|
depends = Depends( |
||||
|
dependency_factory.get_dependency(), |
||||
|
dependency_scope="lifespan", |
||||
|
use_cache=use_cache, |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
if routing_style == "app_endpoint": |
||||
|
router = app |
||||
|
|
||||
|
else: |
||||
|
router = APIRouter() |
||||
|
|
||||
|
create_endpoint_1_annotation( |
||||
|
router=router, |
||||
|
path="/test", |
||||
|
is_websocket=is_websocket, |
||||
|
annotation=Annotated[int, depends], |
||||
|
expected_value=1, |
||||
|
) |
||||
|
|
||||
|
if routing_style == "router_endpoint": |
||||
|
app.include_router(router) |
||||
|
|
||||
|
with pytest.raises(IntentionallyBadDependency) as exception_info: |
||||
|
with TestClient(app): |
||||
|
pass |
||||
|
|
||||
|
assert exception_info.value.args == (1,) |
||||
|
|
||||
|
|
||||
|
def test_endpoint_dependant_backwards_compatibility(): |
||||
|
dependency_factory = DependencyFactory(DependencyStyle.ASYNC_GENERATOR) |
||||
|
|
||||
|
def endpoint( |
||||
|
dependency1: Annotated[int, Depends(dependency_factory.get_dependency())], |
||||
|
dependency2: Annotated[ |
||||
|
int, |
||||
|
Depends(dependency_factory.get_dependency(), dependency_scope="lifespan"), |
||||
|
], |
||||
|
): |
||||
|
pass # pragma: nocover |
||||
|
|
||||
|
dependant = get_endpoint_dependant( |
||||
|
path="/test", |
||||
|
call=endpoint, |
||||
|
name="endpoint", |
||||
|
) |
||||
|
|
||||
|
assert dependant.dependencies == tuple( |
||||
|
dependant.lifespan_dependencies + dependant.endpoint_dependencies |
||||
|
) |
@ -0,0 +1,193 @@ |
|||||
|
from enum import Enum |
||||
|
from typing import Any, AsyncGenerator, Generator, List, TypeVar, Union |
||||
|
|
||||
|
from fastapi import APIRouter, FastAPI, WebSocket |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import assert_never |
||||
|
|
||||
|
T = TypeVar("T") |
||||
|
|
||||
|
|
||||
|
class DependencyStyle(str, Enum): |
||||
|
SYNC_FUNCTION = "sync_function" |
||||
|
ASYNC_FUNCTION = "async_function" |
||||
|
SYNC_GENERATOR = "sync_generator" |
||||
|
ASYNC_GENERATOR = "async_generator" |
||||
|
|
||||
|
|
||||
|
class IntentionallyBadDependency(Exception): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class DependencyFactory: |
||||
|
def __init__( |
||||
|
self, |
||||
|
dependency_style: DependencyStyle, |
||||
|
*, |
||||
|
should_error: bool = False, |
||||
|
value_offset: int = 0, |
||||
|
): |
||||
|
self.activation_times = 0 |
||||
|
self.deactivation_times = 0 |
||||
|
self.dependency_style = dependency_style |
||||
|
self._should_error = should_error |
||||
|
self._value_offset = value_offset |
||||
|
|
||||
|
def get_dependency(self): |
||||
|
if self.dependency_style == DependencyStyle.SYNC_FUNCTION: |
||||
|
return self._synchronous_function_dependency |
||||
|
|
||||
|
if self.dependency_style == DependencyStyle.SYNC_GENERATOR: |
||||
|
return self._synchronous_generator_dependency |
||||
|
|
||||
|
if self.dependency_style == DependencyStyle.ASYNC_FUNCTION: |
||||
|
return self._asynchronous_function_dependency |
||||
|
|
||||
|
if self.dependency_style == DependencyStyle.ASYNC_GENERATOR: |
||||
|
return self._asynchronous_generator_dependency |
||||
|
|
||||
|
assert_never(self.dependency_style) # pragma: nocover |
||||
|
|
||||
|
async def _asynchronous_generator_dependency(self) -> AsyncGenerator[T, None]: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
yield self.activation_times + self._value_offset |
||||
|
self.deactivation_times += 1 |
||||
|
|
||||
|
def _synchronous_generator_dependency(self) -> Generator[T, None, None]: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
yield self.activation_times + self._value_offset |
||||
|
self.deactivation_times += 1 |
||||
|
|
||||
|
async def _asynchronous_function_dependency(self) -> T: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
return self.activation_times + self._value_offset |
||||
|
|
||||
|
def _synchronous_function_dependency(self) -> T: |
||||
|
self.activation_times += 1 |
||||
|
if self._should_error: |
||||
|
raise IntentionallyBadDependency(self.activation_times) |
||||
|
|
||||
|
return self.activation_times + self._value_offset |
||||
|
|
||||
|
|
||||
|
def use_endpoint(client: TestClient, url: str) -> Any: |
||||
|
response = client.post(url) |
||||
|
response.raise_for_status() |
||||
|
return response.json() |
||||
|
|
||||
|
|
||||
|
def use_websocket(client: TestClient, url: str) -> Any: |
||||
|
with client.websocket_connect(url) as connection: |
||||
|
return connection.receive_json() |
||||
|
|
||||
|
|
||||
|
def create_endpoint_0_annotations( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
|
||||
|
@router.websocket(path) |
||||
|
async def endpoint(websocket: WebSocket) -> None: |
||||
|
await websocket.accept() |
||||
|
await websocket.send_json(None) |
||||
|
else: |
||||
|
|
||||
|
@router.post(path) |
||||
|
async def endpoint() -> None: |
||||
|
return None |
||||
|
|
||||
|
|
||||
|
def create_endpoint_1_annotation( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
annotation: Any, |
||||
|
expected_value: Any = None, |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
|
||||
|
@router.websocket(path) |
||||
|
async def endpoint(websocket: WebSocket, value: annotation) -> None: |
||||
|
if expected_value is not None: |
||||
|
assert value == expected_value |
||||
|
|
||||
|
await websocket.accept() |
||||
|
await websocket.send_json(value) |
||||
|
else: |
||||
|
|
||||
|
@router.post(path) |
||||
|
async def endpoint(value: annotation) -> Any: |
||||
|
if expected_value is not None: |
||||
|
assert value == expected_value |
||||
|
|
||||
|
return value |
||||
|
|
||||
|
|
||||
|
def create_endpoint_2_annotations( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
annotation1: Any, |
||||
|
annotation2: Any, |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
|
||||
|
@router.websocket(path) |
||||
|
async def endpoint( |
||||
|
websocket: WebSocket, |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
) -> None: |
||||
|
await websocket.accept() |
||||
|
await websocket.send_json([value1, value2]) |
||||
|
else: |
||||
|
|
||||
|
@router.post(path) |
||||
|
async def endpoint( |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
) -> List[Any]: |
||||
|
return [value1, value2] |
||||
|
|
||||
|
|
||||
|
def create_endpoint_3_annotations( |
||||
|
*, |
||||
|
router: Union[APIRouter, FastAPI], |
||||
|
path: str, |
||||
|
is_websocket: bool, |
||||
|
annotation1: Any, |
||||
|
annotation2: Any, |
||||
|
annotation3: Any, |
||||
|
) -> None: |
||||
|
if is_websocket: |
||||
|
|
||||
|
@router.websocket(path) |
||||
|
async def endpoint( |
||||
|
websocket: WebSocket, |
||||
|
value1: annotation1, |
||||
|
value2: annotation2, |
||||
|
value3: annotation3, |
||||
|
) -> None: |
||||
|
await websocket.accept() |
||||
|
await websocket.send_json([value1, value2, value3]) |
||||
|
else: |
||||
|
|
||||
|
@router.post(path) |
||||
|
async def endpoint( |
||||
|
value1: annotation1, value2: annotation2, value3: annotation3 |
||||
|
) -> List[Any]: |
||||
|
return [value1, value2, value3] |
@ -0,0 +1,65 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
||||
|
mock = MockDatabaseConnection() |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mock): |
||||
|
assert database_connection_mock.enter_count == 0 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
with TestClient(app) as test_client: |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 1 |
||||
|
|
||||
|
response = test_client.get("/items") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 2 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,70 @@ |
|||||
|
import sys |
||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
if sys.version_info >= (3, 9): |
||||
|
from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
from ...utils import needs_py39 |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
||||
|
mock = MockDatabaseConnection() |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
|
||||
|
@needs_py39 |
||||
|
def test_dependency_usage(database_connection_mock): |
||||
|
assert database_connection_mock.enter_count == 0 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
with TestClient(app) as test_client: |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 1 |
||||
|
|
||||
|
response = test_client.get("/items") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 2 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,130 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
|
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 3 |
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
assert connection.get_records_count == 0 |
||||
|
assert connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
users_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1: |
||||
|
users_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert users_connection is not None, ( |
||||
|
"No connection was found for users endpoint" |
||||
|
) |
||||
|
|
||||
|
response = test_client.get("/groups") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
groups_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1 and connection is not users_connection: |
||||
|
groups_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert groups_connection is not None, ( |
||||
|
"No connection was found for groups endpoint" |
||||
|
) |
||||
|
assert groups_connection.get_records_count == 1 |
||||
|
|
||||
|
items_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 0: |
||||
|
items_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert items_connection is not None, ( |
||||
|
"No connection was found for items endpoint" |
||||
|
) |
||||
|
|
||||
|
response = test_client.get("/items") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/items/asd") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "items", |
||||
|
"record_id": "asd", |
||||
|
} |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 1 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 1 |
@ -0,0 +1,135 @@ |
|||||
|
import sys |
||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
if sys.version_info >= (3, 9): |
||||
|
from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
from ...utils import needs_py39 |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
|
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
@needs_py39 |
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 3 |
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
assert connection.get_records_count == 0 |
||||
|
assert connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
users_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1: |
||||
|
users_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert users_connection is not None, ( |
||||
|
"No connection was found for users endpoint" |
||||
|
) |
||||
|
|
||||
|
response = test_client.get("/groups") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
groups_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1 and connection is not users_connection: |
||||
|
groups_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert groups_connection is not None, ( |
||||
|
"No connection was found for groups endpoint" |
||||
|
) |
||||
|
assert groups_connection.get_records_count == 1 |
||||
|
|
||||
|
items_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 0: |
||||
|
items_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert items_connection is not None, ( |
||||
|
"No connection was found for items endpoint" |
||||
|
) |
||||
|
|
||||
|
response = test_client.get("/items") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/items/asd") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "items", |
||||
|
"record_id": "asd", |
||||
|
} |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 1 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 1 |
@ -0,0 +1,78 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self, url: str): |
||||
|
self.url = url |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
|
||||
|
def _get_new_connection_mock(cls, url): |
||||
|
mock = MockDatabaseConnection(url) |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.url == "sqlite:///database.db" |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users/user") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,83 @@ |
|||||
|
import sys |
||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
if sys.version_info >= (3, 9): |
||||
|
from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
from ...utils import needs_py39 |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self, url: str): |
||||
|
self.url = url |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
|
||||
|
def _get_new_connection_mock(cls, url): |
||||
|
mock = MockDatabaseConnection(url) |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
@needs_py39 |
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.url == "sqlite:///database.db" |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users/user") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,76 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
|
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users/user") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,81 @@ |
|||||
|
import sys |
||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
if sys.version_info >= (3, 9): |
||||
|
from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
from ...utils import needs_py39 |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
|
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
@needs_py39 |
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get("/users/user") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
Loading…
Reference in new issue