17 changed files with 1193 additions and 0 deletions
@ -0,0 +1,99 @@ |
|||||
|
# Lifespan Scoped Dependencies |
||||
|
|
||||
|
So far we've used dependencies which are "endpoint scoped". Meaning, they are |
||||
|
called again and again for every incoming request to the endpoint. However, |
||||
|
this is not ideal for all kinds of dependencies. |
||||
|
|
||||
|
Sometimes dependencies have a large setup/teardown time, or there is a need |
||||
|
for their value to be shared throughout the lifespan of the application. An |
||||
|
example of this would be a connection to a database. Databases are typically |
||||
|
less efficient when working with lots of connections and would prefer that |
||||
|
clients would create a single connection for their operations. |
||||
|
|
||||
|
For such cases, you might want to use "lifespan scoped" dependencies. |
||||
|
|
||||
|
## Intro |
||||
|
|
||||
|
Lifespan scoped dependencies work similarly to the dependencies we've worked |
||||
|
with so far (which are endpoint scoped). However, they are called once and only |
||||
|
once in the application's lifespan (instead of being called again and again for |
||||
|
every request). The returned value will be shared across all requests that need |
||||
|
it. |
||||
|
|
||||
|
|
||||
|
## Create a lifespan scoped dependency |
||||
|
|
||||
|
You may declare a dependency as a lifespan scoped dependency by passing |
||||
|
`dependency_scope="lifespan"` to the `Depends` function: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013a_an_py39.py hl[16] *} |
||||
|
|
||||
|
/// tip |
||||
|
|
||||
|
In the example above we saved the annotation to a separate variable, and then |
||||
|
reused it in our endpoints. This is not a requirement, we could also declare |
||||
|
the exact same annotation in both endpoints. However, it is recommended that you |
||||
|
do save the annotation to a variable so you won't accidentally forget to pass |
||||
|
`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint |
||||
|
to create a new database connection for every request). |
||||
|
|
||||
|
/// |
||||
|
|
||||
|
In this example, the `get_database_connection` dependency will be executed once, |
||||
|
during the application's startup. **FastAPI** will internally save the resulting |
||||
|
connection object, and whenever the `read_users` and `read_items` endpoints are |
||||
|
called, they will be using the previously saved connection. Once the application |
||||
|
shuts down, **FastAPI** will make sure to gracefully close the connection object. |
||||
|
|
||||
|
## The `use_cache` argument |
||||
|
|
||||
|
The `use_cache` argument works similarly to the way it worked with endpoint |
||||
|
scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it |
||||
|
will cache dependencies it already encountered before. However, you can disable |
||||
|
this behavior by passing `use_cache=False` to `Depends`: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013b_an_py39.py hl[16] *} |
||||
|
|
||||
|
In this example, the `read_users` and `read_groups` endpoints are using |
||||
|
`use_cache=False` whereas the `read_items` and `read_item` are using |
||||
|
`use_cache=True`. That means that we'll have a total of 3 connections created |
||||
|
for the duration of the application's lifespan. One connection will be shared |
||||
|
across all requests for the `read_items` and `read_item` endpoints. A second |
||||
|
connection will be shared across all requests for the `read_users` endpoint. The |
||||
|
third and final connection will be shared across all requests for the |
||||
|
`read_groups` endpoint. |
||||
|
|
||||
|
|
||||
|
## Lifespan Scoped Sub-Dependencies |
||||
|
Just like with endpoint scoped dependencies, lifespan scoped dependencies may |
||||
|
use other lifespan scoped sub-dependencies themselves: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013c_an_py39.py hl[16] *} |
||||
|
|
||||
|
Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: |
||||
|
|
||||
|
{* ../../docs_src/dependencies/tutorial013d_an_py39.py hl[16] *} |
||||
|
|
||||
|
/// note |
||||
|
|
||||
|
You can pass `dependency_scope="endpoint"` if you wish to explicitly specify |
||||
|
that a dependency is endpoint scoped. It will work the same as not specifying |
||||
|
a dependency scope at all. |
||||
|
|
||||
|
/// |
||||
|
|
||||
|
As you can see, regardless of the scope, dependencies can use lifespan scoped |
||||
|
sub-dependencies. |
||||
|
|
||||
|
## Dependency Scope Conflicts |
||||
|
By definition, lifespan scoped dependencies are being setup in the application's |
||||
|
startup process, before any request is ever being made to any endpoint. |
||||
|
Therefore, it is not possible for a lifespan scoped dependency to use any |
||||
|
parameters that require the scope of an endpoint. |
||||
|
|
||||
|
That includes but not limited to: |
||||
|
* Parts of the request (like `Body`, `Query` and `Path`) |
||||
|
* The request/response objects themselves (like `Request`, `Response` and `WebSocket`) |
||||
|
* Endpoint scoped sub-dependencies. |
||||
|
|
||||
|
Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. |
@ -0,0 +1,39 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
from fastapi import Depends, FastAPI |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
pass |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("items") |
@ -0,0 +1,38 @@ |
|||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> list[dict]: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] |
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users(database_connection: GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items(database_connection: GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("items") |
@ -0,0 +1,54 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
DedicatedDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan", use_cache=False) |
||||
|
|
||||
|
@app.get("/groups/") |
||||
|
async def read_groups(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): |
||||
|
return await database_connection.get_records("groups") |
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("items") |
||||
|
|
||||
|
@app.get("/items/{item_id}") |
||||
|
async def read_item( |
||||
|
item_id: str = Path(), |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection |
||||
|
): |
||||
|
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,53 @@ |
|||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> list[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] |
||||
|
DedicatedDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)] |
||||
|
|
||||
|
@app.get("/groups/") |
||||
|
async def read_groups(database_connection: DedicatedDatabaseConnection): |
||||
|
return await database_connection.get_records("groups") |
||||
|
|
||||
|
@app.get("/users/") |
||||
|
async def read_users(database_connection: DedicatedDatabaseConnection): |
||||
|
return await database_connection.get_records("users") |
||||
|
|
||||
|
|
||||
|
@app.get("/items/") |
||||
|
async def read_items(database_connection: GlobalDatabaseConnection): |
||||
|
return await database_connection.get_records("items") |
||||
|
|
||||
|
@app.get("/items/{item_id}") |
||||
|
async def read_item( |
||||
|
database_connection: GlobalDatabaseConnection, |
||||
|
item_id: Annotated[str, Path()] |
||||
|
): |
||||
|
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,47 @@ |
|||||
|
from dataclasses import dataclass |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
connection_string: str |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_configuration() -> dict: |
||||
|
return { |
||||
|
"database_url": "sqlite:///database.db", |
||||
|
} |
||||
|
|
||||
|
GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
async def get_database_connection( |
||||
|
configuration: dict = GlobalConfiguration |
||||
|
): |
||||
|
async with MyDatabaseConnection(configuration["database_url"]) as connection: |
||||
|
yield connection |
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
user_id: str = Path() |
||||
|
): |
||||
|
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,52 @@ |
|||||
|
from dataclasses import dataclass |
||||
|
from typing import Annotated, List |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
connection_string: str |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_configuration() -> dict: |
||||
|
return { |
||||
|
"database_url": "sqlite:///database.db", |
||||
|
} |
||||
|
|
||||
|
GlobalConfiguration = Annotated[dict, Depends(get_configuration, dependency_scope="lifespan")] |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(configuration: GlobalConfiguration): |
||||
|
async with MyDatabaseConnection( |
||||
|
configuration["database_url"]) as connection: |
||||
|
yield connection |
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[get_database_connection, Depends(get_database_connection, dependency_scope="lifespan")] |
||||
|
|
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user( |
||||
|
database_connection: GlobalDatabaseConnection, |
||||
|
user_id: Annotated[str, Path()] |
||||
|
): |
||||
|
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,46 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
||||
|
|
||||
|
|
||||
|
async def get_user_record( |
||||
|
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
||||
|
user_id: str = Path() |
||||
|
) -> dict: |
||||
|
return await database_connection.get_record("users", user_id) |
||||
|
|
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user( |
||||
|
user_record: dict = Depends(get_user_record) |
||||
|
): |
||||
|
return user_record |
@ -0,0 +1,45 @@ |
|||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Path |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
|
||||
|
class MyDatabaseConnection: |
||||
|
""" |
||||
|
This is a mock just for example purposes. |
||||
|
""" |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
return self |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
pass |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> list[dict]: |
||||
|
pass |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
pass |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database_connection(): |
||||
|
async with MyDatabaseConnection() as connection: |
||||
|
yield connection |
||||
|
|
||||
|
|
||||
|
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] |
||||
|
|
||||
|
|
||||
|
async def get_user_record( |
||||
|
database_connection: GlobalDatabaseConnection, |
||||
|
user_id: Annotated[str, Path()] |
||||
|
) -> dict: |
||||
|
return await database_connection.get_record("users", user_id) |
||||
|
|
||||
|
@app.get("/users/{user_id}") |
||||
|
async def read_user( |
||||
|
user_record: Annotated[dict, Depends(get_user_record)] |
||||
|
): |
||||
|
return user_record |
@ -0,0 +1,69 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
||||
|
mock = MockDatabaseConnection() |
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
lambda *args, **kwargs: mock |
||||
|
) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mock): |
||||
|
assert database_connection_mock.enter_count == 0 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
with TestClient(app) as test_client: |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 1 |
||||
|
|
||||
|
response = test_client.get('/items') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 2 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,69 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
||||
|
mock = MockDatabaseConnection() |
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
lambda *args, **kwargs: mock |
||||
|
) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mock): |
||||
|
assert database_connection_mock.enter_count == 0 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
with TestClient(app) as test_client: |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 1 |
||||
|
|
||||
|
response = test_client.get('/items') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert database_connection_mock.get_records_count == 2 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,129 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
_get_new_connection_mock |
||||
|
) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 3 |
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
assert connection.get_records_count == 0 |
||||
|
assert connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
users_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1: |
||||
|
users_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert users_connection is not None, "No connection was found for users endpoint" |
||||
|
|
||||
|
response = test_client.get('/groups') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
groups_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1 and connection is not users_connection: |
||||
|
groups_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert groups_connection is not None, "No connection was found for groups endpoint" |
||||
|
assert groups_connection.get_records_count == 1 |
||||
|
|
||||
|
items_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 0: |
||||
|
items_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert items_connection is not None, "No connection was found for items endpoint" |
||||
|
|
||||
|
response = test_client.get('/items') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/items/asd') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "items", |
||||
|
"record_id": "asd", |
||||
|
} |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 1 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 1 |
@ -0,0 +1,129 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_records_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_records(self, table_name: str) -> List[dict]: |
||||
|
self.get_records_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return [] |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_records(self, table_name) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
_get_new_connection_mock |
||||
|
) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 3 |
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
assert connection.get_records_count == 0 |
||||
|
assert connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
users_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1: |
||||
|
users_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert users_connection is not None, "No connection was found for users endpoint" |
||||
|
|
||||
|
response = test_client.get('/groups') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
groups_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 1 and connection is not users_connection: |
||||
|
groups_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert groups_connection is not None, "No connection was found for groups endpoint" |
||||
|
assert groups_connection.get_records_count == 1 |
||||
|
|
||||
|
items_connection = None |
||||
|
for connection in database_connection_mocks: |
||||
|
if connection.get_records_count == 0: |
||||
|
items_connection = connection |
||||
|
break |
||||
|
|
||||
|
assert items_connection is not None, "No connection was found for items endpoint" |
||||
|
|
||||
|
response = test_client.get('/items') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == [] |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/items/asd') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "items", |
||||
|
"record_id": "asd", |
||||
|
} |
||||
|
|
||||
|
assert items_connection.get_records_count == 1 |
||||
|
assert items_connection.get_record_count == 1 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 0 |
||||
|
|
||||
|
for connection in database_connection_mocks: |
||||
|
assert connection.enter_count == 1 |
||||
|
assert connection.exit_count == 1 |
@ -0,0 +1,82 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self, url: str): |
||||
|
self.url = url |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
def _get_new_connection_mock(cls, url): |
||||
|
mock = MockDatabaseConnection(url) |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
_get_new_connection_mock |
||||
|
) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.url == "sqlite:///database.db" |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users/user') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,82 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self, url: str): |
||||
|
self.url = url |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
def _get_new_connection_mock(cls, url): |
||||
|
mock = MockDatabaseConnection(url) |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
_get_new_connection_mock |
||||
|
) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.url == "sqlite:///database.db" |
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users/user') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,80 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
_get_new_connection_mock |
||||
|
) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users/user') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,80 @@ |
|||||
|
from typing import List |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.testclient import TestClient |
||||
|
from typing_extensions import Self |
||||
|
|
||||
|
from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app |
||||
|
|
||||
|
|
||||
|
class MockDatabaseConnection: |
||||
|
def __init__(self): |
||||
|
self.enter_count = 0 |
||||
|
self.exit_count = 0 |
||||
|
self.get_record_count = 0 |
||||
|
|
||||
|
async def __aenter__(self) -> Self: |
||||
|
self.enter_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aenter__(self) |
||||
|
|
||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
|
self.exit_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
||||
|
|
||||
|
async def get_record(self, table_name: str, record_id: str) -> dict: |
||||
|
self.get_record_count += 1 |
||||
|
# Called for the sake of coverage. |
||||
|
await MyDatabaseConnection.get_record(self, table_name, record_id) |
||||
|
return { |
||||
|
"table_name": table_name, |
||||
|
"record_id": record_id, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
@pytest.fixture |
||||
|
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
||||
|
connections = [] |
||||
|
def _get_new_connection_mock(*args, **kwargs): |
||||
|
mock = MockDatabaseConnection() |
||||
|
connections.append(mock) |
||||
|
|
||||
|
return mock |
||||
|
|
||||
|
monkeypatch.setattr( |
||||
|
MyDatabaseConnection, |
||||
|
"__new__", |
||||
|
_get_new_connection_mock |
||||
|
) |
||||
|
return connections |
||||
|
|
||||
|
|
||||
|
def test_dependency_usage(database_connection_mocks): |
||||
|
assert len(database_connection_mocks) == 0 |
||||
|
|
||||
|
with TestClient(app) as test_client: |
||||
|
assert len(database_connection_mocks) == 1 |
||||
|
[database_connection_mock] = database_connection_mocks |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 0 |
||||
|
|
||||
|
response = test_client.get('/users/user') |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == { |
||||
|
"table_name": "users", |
||||
|
"record_id": "user", |
||||
|
} |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 0 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert database_connection_mock.enter_count == 1 |
||||
|
assert database_connection_mock.exit_count == 1 |
||||
|
assert database_connection_mock.get_record_count == 1 |
||||
|
|
||||
|
assert len(database_connection_mocks) == 1 |
Loading…
Reference in new issue