17 changed files with 1193 additions and 0 deletions
@ -0,0 +1,99 @@ |
|||
# Lifespan Scoped Dependencies |
|||
|
|||
So far we've used dependencies which are "endpoint scoped". Meaning, they are |
|||
called again and again for every incoming request to the endpoint. However, |
|||
this is not ideal for all kinds of dependencies. |
|||
|
|||
Sometimes dependencies have a large setup/teardown time, or there is a need |
|||
for their value to be shared throughout the lifespan of the application. An |
|||
example of this would be a connection to a database. Databases are typically |
|||
less efficient when working with lots of connections and would prefer that |
|||
clients would create a single connection for their operations. |
|||
|
|||
For such cases, you might want to use "lifespan scoped" dependencies. |
|||
|
|||
## Intro |
|||
|
|||
Lifespan scoped dependencies work similarly to the dependencies we've worked |
|||
with so far (which are endpoint scoped). However, they are called once and only |
|||
once in the application's lifespan (instead of being called again and again for |
|||
every request). The returned value will be shared across all requests that need |
|||
it. |
|||
|
|||
|
|||
## Create a lifespan scoped dependency |
|||
|
|||
You may declare a dependency as a lifespan scoped dependency by passing |
|||
`dependency_scope="lifespan"` to the `Depends` function: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013a_an_py39.py hl[16] *} |
|||
|
|||
/// tip |
|||
|
|||
In the example above we saved the annotation to a separate variable, and then |
|||
reused it in our endpoints. This is not a requirement, we could also declare |
|||
the exact same annotation in both endpoints. However, it is recommended that you |
|||
do save the annotation to a variable so you won't accidentally forget to pass |
|||
`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint |
|||
to create a new database connection for every request). |
|||
|
|||
/// |
|||
|
|||
In this example, the `get_database_connection` dependency will be executed once, |
|||
during the application's startup. **FastAPI** will internally save the resulting |
|||
connection object, and whenever the `read_users` and `read_items` endpoints are |
|||
called, they will be using the previously saved connection. Once the application |
|||
shuts down, **FastAPI** will make sure to gracefully close the connection object. |
|||
|
|||
## The `use_cache` argument |
|||
|
|||
The `use_cache` argument works similarly to the way it worked with endpoint |
|||
scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it |
|||
will cache dependencies it already encountered before. However, you can disable |
|||
this behavior by passing `use_cache=False` to `Depends`: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013b_an_py39.py hl[16] *} |
|||
|
|||
In this example, the `read_users` and `read_groups` endpoints are using |
|||
`use_cache=False` whereas the `read_items` and `read_item` are using |
|||
`use_cache=True`. That means that we'll have a total of 3 connections created |
|||
for the duration of the application's lifespan. One connection will be shared |
|||
across all requests for the `read_items` and `read_item` endpoints. A second |
|||
connection will be shared across all requests for the `read_users` endpoint. The |
|||
third and final connection will be shared across all requests for the |
|||
`read_groups` endpoint. |
|||
|
|||
|
|||
## Lifespan Scoped Sub-Dependencies |
|||
Just like with endpoint scoped dependencies, lifespan scoped dependencies may |
|||
use other lifespan scoped sub-dependencies themselves: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013c_an_py39.py hl[16] *} |
|||
|
|||
Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: |
|||
|
|||
{* ../../docs_src/dependencies/tutorial013d_an_py39.py hl[16] *} |
|||
|
|||
/// note |
|||
|
|||
You can pass `dependency_scope="endpoint"` if you wish to explicitly specify |
|||
that a dependency is endpoint scoped. It will work the same as not specifying |
|||
a dependency scope at all. |
|||
|
|||
/// |
|||
|
|||
As you can see, regardless of the scope, dependencies can use lifespan scoped |
|||
sub-dependencies. |
|||
|
|||
## Dependency Scope Conflicts |
|||
By definition, lifespan scoped dependencies are being setup in the application's |
|||
startup process, before any request is ever being made to any endpoint. |
|||
Therefore, it is not possible for a lifespan scoped dependency to use any |
|||
parameters that require the scope of an endpoint. |
|||
|
|||
That includes but not limited to: |
|||
* Parts of the request (like `Body`, `Query` and `Path`) |
|||
* The request/response objects themselves (like `Request`, `Response` and `WebSocket`) |
|||
* Endpoint scoped sub-dependencies. |
|||
|
|||
Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. |
@ -0,0 +1,39 @@ |
|||
from typing import List |
|||
|
|||
from fastapi import Depends, FastAPI |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
pass |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
|
|||
|
|||
@app.get("/users/") |
|||
async def read_users(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): |
|||
return await database_connection.get_records("items") |
@ -0,0 +1,38 @@ |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> list[dict]: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] |
|||
|
|||
@app.get("/users/") |
|||
async def read_users(database_connection: GlobalDatabaseConnection): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items(database_connection: GlobalDatabaseConnection): |
|||
return await database_connection.get_records("items") |
@ -0,0 +1,54 @@ |
|||
from typing import List |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
DedicatedDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan", use_cache=False) |
|||
|
|||
@app.get("/groups/") |
|||
async def read_groups(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): |
|||
return await database_connection.get_records("groups") |
|||
|
|||
@app.get("/users/") |
|||
async def read_users(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): |
|||
return await database_connection.get_records("items") |
|||
|
|||
@app.get("/items/{item_id}") |
|||
async def read_item( |
|||
item_id: str = Path(), |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection |
|||
): |
|||
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,53 @@ |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> list[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] |
|||
DedicatedDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)] |
|||
|
|||
@app.get("/groups/") |
|||
async def read_groups(database_connection: DedicatedDatabaseConnection): |
|||
return await database_connection.get_records("groups") |
|||
|
|||
@app.get("/users/") |
|||
async def read_users(database_connection: DedicatedDatabaseConnection): |
|||
return await database_connection.get_records("users") |
|||
|
|||
|
|||
@app.get("/items/") |
|||
async def read_items(database_connection: GlobalDatabaseConnection): |
|||
return await database_connection.get_records("items") |
|||
|
|||
@app.get("/items/{item_id}") |
|||
async def read_item( |
|||
database_connection: GlobalDatabaseConnection, |
|||
item_id: Annotated[str, Path()] |
|||
): |
|||
return await database_connection.get_record("items", item_id) |
@ -0,0 +1,47 @@ |
|||
from dataclasses import dataclass |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
@dataclass |
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
connection_string: str |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_configuration() -> dict: |
|||
return { |
|||
"database_url": "sqlite:///database.db", |
|||
} |
|||
|
|||
GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") |
|||
|
|||
|
|||
async def get_database_connection( |
|||
configuration: dict = GlobalConfiguration |
|||
): |
|||
async with MyDatabaseConnection(configuration["database_url"]) as connection: |
|||
yield connection |
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
user_id: str = Path() |
|||
): |
|||
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,52 @@ |
|||
from dataclasses import dataclass |
|||
from typing import Annotated, List |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
@dataclass |
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
connection_string: str |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_configuration() -> dict: |
|||
return { |
|||
"database_url": "sqlite:///database.db", |
|||
} |
|||
|
|||
GlobalConfiguration = Annotated[dict, Depends(get_configuration, dependency_scope="lifespan")] |
|||
|
|||
|
|||
async def get_database_connection(configuration: GlobalConfiguration): |
|||
async with MyDatabaseConnection( |
|||
configuration["database_url"]) as connection: |
|||
yield connection |
|||
|
|||
GlobalDatabaseConnection = Annotated[get_database_connection, Depends(get_database_connection, dependency_scope="lifespan")] |
|||
|
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user( |
|||
database_connection: GlobalDatabaseConnection, |
|||
user_id: Annotated[str, Path()] |
|||
): |
|||
return await database_connection.get_record("users", user_id) |
@ -0,0 +1,46 @@ |
|||
from typing import List |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") |
|||
|
|||
|
|||
async def get_user_record( |
|||
database_connection: MyDatabaseConnection = GlobalDatabaseConnection, |
|||
user_id: str = Path() |
|||
) -> dict: |
|||
return await database_connection.get_record("users", user_id) |
|||
|
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user( |
|||
user_record: dict = Depends(get_user_record) |
|||
): |
|||
return user_record |
@ -0,0 +1,45 @@ |
|||
from typing import Annotated |
|||
|
|||
from fastapi import Depends, FastAPI, Path |
|||
from typing_extensions import Self |
|||
|
|||
|
|||
class MyDatabaseConnection: |
|||
""" |
|||
This is a mock just for example purposes. |
|||
""" |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
return self |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
pass |
|||
|
|||
async def get_records(self, table_name: str) -> list[dict]: |
|||
pass |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
pass |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database_connection(): |
|||
async with MyDatabaseConnection() as connection: |
|||
yield connection |
|||
|
|||
|
|||
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] |
|||
|
|||
|
|||
async def get_user_record( |
|||
database_connection: GlobalDatabaseConnection, |
|||
user_id: Annotated[str, Path()] |
|||
) -> dict: |
|||
return await database_connection.get_record("users", user_id) |
|||
|
|||
@app.get("/users/{user_id}") |
|||
async def read_user( |
|||
user_record: Annotated[dict, Depends(get_user_record)] |
|||
): |
|||
return user_record |
@ -0,0 +1,69 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
|||
mock = MockDatabaseConnection() |
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
lambda *args, **kwargs: mock |
|||
) |
|||
|
|||
return mock |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mock): |
|||
assert database_connection_mock.enter_count == 0 |
|||
assert database_connection_mock.exit_count == 0 |
|||
with TestClient(app) as test_client: |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
response = test_client.get('/users') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 1 |
|||
|
|||
response = test_client.get('/items') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 2 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,69 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mock(monkeypatch) -> MockDatabaseConnection: |
|||
mock = MockDatabaseConnection() |
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
lambda *args, **kwargs: mock |
|||
) |
|||
|
|||
return mock |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mock): |
|||
assert database_connection_mock.enter_count == 0 |
|||
assert database_connection_mock.exit_count == 0 |
|||
with TestClient(app) as test_client: |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
response = test_client.get('/users') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 1 |
|||
|
|||
response = test_client.get('/items') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert database_connection_mock.get_records_count == 2 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
@ -0,0 +1,129 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
_get_new_connection_mock |
|||
) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 3 |
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
assert connection.get_records_count == 0 |
|||
assert connection.get_record_count == 0 |
|||
|
|||
response = test_client.get('/users') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
users_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1: |
|||
users_connection = connection |
|||
break |
|||
|
|||
assert users_connection is not None, "No connection was found for users endpoint" |
|||
|
|||
response = test_client.get('/groups') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
groups_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1 and connection is not users_connection: |
|||
groups_connection = connection |
|||
break |
|||
|
|||
assert groups_connection is not None, "No connection was found for groups endpoint" |
|||
assert groups_connection.get_records_count == 1 |
|||
|
|||
items_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 0: |
|||
items_connection = connection |
|||
break |
|||
|
|||
assert items_connection is not None, "No connection was found for items endpoint" |
|||
|
|||
response = test_client.get('/items') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 0 |
|||
|
|||
response = test_client.get('/items/asd') |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "items", |
|||
"record_id": "asd", |
|||
} |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 1 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 1 |
@ -0,0 +1,129 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_records_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_records(self, table_name: str) -> List[dict]: |
|||
self.get_records_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return [] |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_records(self, table_name) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
_get_new_connection_mock |
|||
) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 3 |
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
assert connection.get_records_count == 0 |
|||
assert connection.get_record_count == 0 |
|||
|
|||
response = test_client.get('/users') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
users_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1: |
|||
users_connection = connection |
|||
break |
|||
|
|||
assert users_connection is not None, "No connection was found for users endpoint" |
|||
|
|||
response = test_client.get('/groups') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
groups_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 1 and connection is not users_connection: |
|||
groups_connection = connection |
|||
break |
|||
|
|||
assert groups_connection is not None, "No connection was found for groups endpoint" |
|||
assert groups_connection.get_records_count == 1 |
|||
|
|||
items_connection = None |
|||
for connection in database_connection_mocks: |
|||
if connection.get_records_count == 0: |
|||
items_connection = connection |
|||
break |
|||
|
|||
assert items_connection is not None, "No connection was found for items endpoint" |
|||
|
|||
response = test_client.get('/items') |
|||
assert response.status_code == 200 |
|||
assert response.json() == [] |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 0 |
|||
|
|||
response = test_client.get('/items/asd') |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "items", |
|||
"record_id": "asd", |
|||
} |
|||
|
|||
assert items_connection.get_records_count == 1 |
|||
assert items_connection.get_record_count == 1 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 0 |
|||
|
|||
for connection in database_connection_mocks: |
|||
assert connection.enter_count == 1 |
|||
assert connection.exit_count == 1 |
@ -0,0 +1,82 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self, url: str): |
|||
self.url = url |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
def _get_new_connection_mock(cls, url): |
|||
mock = MockDatabaseConnection(url) |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
_get_new_connection_mock |
|||
) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.url == "sqlite:///database.db" |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get('/users/user') |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,82 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self, url: str): |
|||
self.url = url |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
def _get_new_connection_mock(cls, url): |
|||
mock = MockDatabaseConnection(url) |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
_get_new_connection_mock |
|||
) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.url == "sqlite:///database.db" |
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get('/users/user') |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,80 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
_get_new_connection_mock |
|||
) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get('/users/user') |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
@ -0,0 +1,80 @@ |
|||
from typing import List |
|||
|
|||
import pytest |
|||
from starlette.testclient import TestClient |
|||
from typing_extensions import Self |
|||
|
|||
from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app |
|||
|
|||
|
|||
class MockDatabaseConnection: |
|||
def __init__(self): |
|||
self.enter_count = 0 |
|||
self.exit_count = 0 |
|||
self.get_record_count = 0 |
|||
|
|||
async def __aenter__(self) -> Self: |
|||
self.enter_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aenter__(self) |
|||
|
|||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
|||
self.exit_count += 1 |
|||
# Called for the sake of coverage. |
|||
return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) |
|||
|
|||
async def get_record(self, table_name: str, record_id: str) -> dict: |
|||
self.get_record_count += 1 |
|||
# Called for the sake of coverage. |
|||
await MyDatabaseConnection.get_record(self, table_name, record_id) |
|||
return { |
|||
"table_name": table_name, |
|||
"record_id": record_id, |
|||
} |
|||
|
|||
|
|||
|
|||
@pytest.fixture |
|||
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: |
|||
connections = [] |
|||
def _get_new_connection_mock(*args, **kwargs): |
|||
mock = MockDatabaseConnection() |
|||
connections.append(mock) |
|||
|
|||
return mock |
|||
|
|||
monkeypatch.setattr( |
|||
MyDatabaseConnection, |
|||
"__new__", |
|||
_get_new_connection_mock |
|||
) |
|||
return connections |
|||
|
|||
|
|||
def test_dependency_usage(database_connection_mocks): |
|||
assert len(database_connection_mocks) == 0 |
|||
|
|||
with TestClient(app) as test_client: |
|||
assert len(database_connection_mocks) == 1 |
|||
[database_connection_mock] = database_connection_mocks |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 0 |
|||
|
|||
response = test_client.get('/users/user') |
|||
assert response.status_code == 200 |
|||
assert response.json() == { |
|||
"table_name": "users", |
|||
"record_id": "user", |
|||
} |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 0 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert database_connection_mock.enter_count == 1 |
|||
assert database_connection_mock.exit_count == 1 |
|||
assert database_connection_mock.get_record_count == 1 |
|||
|
|||
assert len(database_connection_mocks) == 1 |
Loading…
Reference in new issue