diff --git a/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md new file mode 100644 index 000000000..d0aea0a16 --- /dev/null +++ b/docs/en/docs/tutorial/dependencies/lifespan-scoped-dependencies.md @@ -0,0 +1,99 @@ +# Lifespan Scoped Dependencies + +So far we've used dependencies which are "endpoint scoped". Meaning, they are +called again and again for every incoming request to the endpoint. However, +this is not ideal for all kinds of dependencies. + +Sometimes dependencies have a large setup/teardown time, or there is a need +for their value to be shared throughout the lifespan of the application. An +example of this would be a connection to a database. Databases are typically +less efficient when working with lots of connections and would prefer that +clients would create a single connection for their operations. + +For such cases, you might want to use "lifespan scoped" dependencies. + +## Intro + +Lifespan scoped dependencies work similarly to the dependencies we've worked +with so far (which are endpoint scoped). However, they are called once and only +once in the application's lifespan (instead of being called again and again for +every request). The returned value will be shared across all requests that need +it. + + +## Create a lifespan scoped dependency + +You may declare a dependency as a lifespan scoped dependency by passing +`dependency_scope="lifespan"` to the `Depends` function: + +{* ../../docs_src/dependencies/tutorial013a_an_py39.py hl[16] *} + +/// tip + +In the example above we saved the annotation to a separate variable, and then +reused it in our endpoints. This is not a requirement, we could also declare +the exact same annotation in both endpoints. However, it is recommended that you +do save the annotation to a variable so you won't accidentally forget to pass +`dependency_scope="lifespan"` to some of the endpoints (Causing the endpoint +to create a new database connection for every request). + +/// + +In this example, the `get_database_connection` dependency will be executed once, +during the application's startup. **FastAPI** will internally save the resulting +connection object, and whenever the `read_users` and `read_items` endpoints are +called, they will be using the previously saved connection. Once the application +shuts down, **FastAPI** will make sure to gracefully close the connection object. + +## The `use_cache` argument + +The `use_cache` argument works similarly to the way it worked with endpoint +scoped dependencies. Meaning as **FastAPI** gathers lifespan scoped dependencies, it +will cache dependencies it already encountered before. However, you can disable +this behavior by passing `use_cache=False` to `Depends`: + +{* ../../docs_src/dependencies/tutorial013b_an_py39.py hl[16] *} + +In this example, the `read_users` and `read_groups` endpoints are using +`use_cache=False` whereas the `read_items` and `read_item` are using +`use_cache=True`. That means that we'll have a total of 3 connections created +for the duration of the application's lifespan. One connection will be shared +across all requests for the `read_items` and `read_item` endpoints. A second +connection will be shared across all requests for the `read_users` endpoint. The +third and final connection will be shared across all requests for the +`read_groups` endpoint. + + +## Lifespan Scoped Sub-Dependencies +Just like with endpoint scoped dependencies, lifespan scoped dependencies may +use other lifespan scoped sub-dependencies themselves: + +{* ../../docs_src/dependencies/tutorial013c_an_py39.py hl[16] *} + +Endpoint scoped dependencies may use lifespan scoped sub dependencies as well: + +{* ../../docs_src/dependencies/tutorial013d_an_py39.py hl[16] *} + +/// note + +You can pass `dependency_scope="endpoint"` if you wish to explicitly specify +that a dependency is endpoint scoped. It will work the same as not specifying +a dependency scope at all. + +/// + +As you can see, regardless of the scope, dependencies can use lifespan scoped +sub-dependencies. + +## Dependency Scope Conflicts +By definition, lifespan scoped dependencies are being setup in the application's +startup process, before any request is ever being made to any endpoint. +Therefore, it is not possible for a lifespan scoped dependency to use any +parameters that require the scope of an endpoint. + +That includes but not limited to: + * Parts of the request (like `Body`, `Query` and `Path`) + * The request/response objects themselves (like `Request`, `Response` and `WebSocket`) + * Endpoint scoped sub-dependencies. + +Defining a dependency with such parameters will raise an `InvalidDependencyScope` error. diff --git a/docs_src/dependencies/tutorial013a.py b/docs_src/dependencies/tutorial013a.py new file mode 100644 index 000000000..e014289d2 --- /dev/null +++ b/docs_src/dependencies/tutorial013a.py @@ -0,0 +1,39 @@ +from typing import List + +from fastapi import Depends, FastAPI +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +@app.get("/users/") +async def read_users(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): + return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013a_an_py39.py b/docs_src/dependencies/tutorial013a_an_py39.py new file mode 100644 index 000000000..c2e8c6672 --- /dev/null +++ b/docs_src/dependencies/tutorial013a_an_py39.py @@ -0,0 +1,38 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] + +@app.get("/users/") +async def read_users(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("items") diff --git a/docs_src/dependencies/tutorial013b.py b/docs_src/dependencies/tutorial013b.py new file mode 100644 index 000000000..0b9907de4 --- /dev/null +++ b/docs_src/dependencies/tutorial013b.py @@ -0,0 +1,54 @@ +from typing import List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") +DedicatedDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan", use_cache=False) + +@app.get("/groups/") +async def read_groups(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): + return await database_connection.get_records("groups") + +@app.get("/users/") +async def read_users(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection): + return await database_connection.get_records("items") + +@app.get("/items/{item_id}") +async def read_item( + item_id: str = Path(), + database_connection: MyDatabaseConnection = GlobalDatabaseConnection +): + return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013b_an_py39.py b/docs_src/dependencies/tutorial013b_an_py39.py new file mode 100644 index 000000000..c5274417d --- /dev/null +++ b/docs_src/dependencies/tutorial013b_an_py39.py @@ -0,0 +1,53 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] +DedicatedDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)] + +@app.get("/groups/") +async def read_groups(database_connection: DedicatedDatabaseConnection): + return await database_connection.get_records("groups") + +@app.get("/users/") +async def read_users(database_connection: DedicatedDatabaseConnection): + return await database_connection.get_records("users") + + +@app.get("/items/") +async def read_items(database_connection: GlobalDatabaseConnection): + return await database_connection.get_records("items") + +@app.get("/items/{item_id}") +async def read_item( + database_connection: GlobalDatabaseConnection, + item_id: Annotated[str, Path()] +): + return await database_connection.get_record("items", item_id) diff --git a/docs_src/dependencies/tutorial013c.py b/docs_src/dependencies/tutorial013c.py new file mode 100644 index 000000000..c4eb99f25 --- /dev/null +++ b/docs_src/dependencies/tutorial013c.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +@dataclass +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + connection_string: str + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_configuration() -> dict: + return { + "database_url": "sqlite:///database.db", + } + +GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan") + + +async def get_database_connection( + configuration: dict = GlobalConfiguration +): + async with MyDatabaseConnection(configuration["database_url"]) as connection: + yield connection + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + +@app.get("/users/{user_id}") +async def read_user( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path() +): + return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013c_an_py39.py b/docs_src/dependencies/tutorial013c_an_py39.py new file mode 100644 index 000000000..1830a4b4e --- /dev/null +++ b/docs_src/dependencies/tutorial013c_an_py39.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Annotated, List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +@dataclass +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + connection_string: str + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + + +app = FastAPI() + + +async def get_configuration() -> dict: + return { + "database_url": "sqlite:///database.db", + } + +GlobalConfiguration = Annotated[dict, Depends(get_configuration, dependency_scope="lifespan")] + + +async def get_database_connection(configuration: GlobalConfiguration): + async with MyDatabaseConnection( + configuration["database_url"]) as connection: + yield connection + +GlobalDatabaseConnection = Annotated[get_database_connection, Depends(get_database_connection, dependency_scope="lifespan")] + + +@app.get("/users/{user_id}") +async def read_user( + database_connection: GlobalDatabaseConnection, + user_id: Annotated[str, Path()] +): + return await database_connection.get_record("users", user_id) diff --git a/docs_src/dependencies/tutorial013d.py b/docs_src/dependencies/tutorial013d.py new file mode 100644 index 000000000..dc6441620 --- /dev/null +++ b/docs_src/dependencies/tutorial013d.py @@ -0,0 +1,46 @@ +from typing import List + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> List[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan") + + +async def get_user_record( + database_connection: MyDatabaseConnection = GlobalDatabaseConnection, + user_id: str = Path() +) -> dict: + return await database_connection.get_record("users", user_id) + + +@app.get("/users/{user_id}") +async def read_user( + user_record: dict = Depends(get_user_record) +): + return user_record diff --git a/docs_src/dependencies/tutorial013d_an_py39.py b/docs_src/dependencies/tutorial013d_an_py39.py new file mode 100644 index 000000000..41aae4c47 --- /dev/null +++ b/docs_src/dependencies/tutorial013d_an_py39.py @@ -0,0 +1,45 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Path +from typing_extensions import Self + + +class MyDatabaseConnection: + """ + This is a mock just for example purposes. + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_records(self, table_name: str) -> list[dict]: + pass + + async def get_record(self, table_name: str, record_id: str) -> dict: + pass + +app = FastAPI() + + +async def get_database_connection(): + async with MyDatabaseConnection() as connection: + yield connection + + +GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")] + + +async def get_user_record( + database_connection: GlobalDatabaseConnection, + user_id: Annotated[str, Path()] +) -> dict: + return await database_connection.get_record("users", user_id) + +@app.get("/users/{user_id}") +async def read_user( + user_record: Annotated[dict, Depends(get_user_record)] +): + return user_record diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a.py b/tests/test_tutorial/test_dependencies/test_tutorial013a.py new file mode 100644 index 000000000..d8a6884c0 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a.py @@ -0,0 +1,69 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013a import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + +@pytest.fixture +def database_connection_mock(monkeypatch) -> MockDatabaseConnection: + mock = MockDatabaseConnection() + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + lambda *args, **kwargs: mock + ) + + return mock + + +def test_dependency_usage(database_connection_mock): + assert database_connection_mock.enter_count == 0 + assert database_connection_mock.exit_count == 0 + with TestClient(app) as test_client: + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 1 + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 2 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py new file mode 100644 index 000000000..05af6b665 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py @@ -0,0 +1,69 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013a_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + +@pytest.fixture +def database_connection_mock(monkeypatch) -> MockDatabaseConnection: + mock = MockDatabaseConnection() + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + lambda *args, **kwargs: mock + ) + + return mock + + +def test_dependency_usage(database_connection_mock): + assert database_connection_mock.enter_count == 0 + assert database_connection_mock.exit_count == 0 + with TestClient(app) as test_client: + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 1 + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert database_connection_mock.get_records_count == 2 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b.py b/tests/test_tutorial/test_dependencies/test_tutorial013b.py new file mode 100644 index 000000000..88942d0f3 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b.py @@ -0,0 +1,129 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013b import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 3 + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + assert connection.get_records_count == 0 + assert connection.get_record_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + users_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1: + users_connection = connection + break + + assert users_connection is not None, "No connection was found for users endpoint" + + response = test_client.get('/groups') + assert response.status_code == 200 + assert response.json() == [] + + groups_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1 and connection is not users_connection: + groups_connection = connection + break + + assert groups_connection is not None, "No connection was found for groups endpoint" + assert groups_connection.get_records_count == 1 + + items_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 0: + items_connection = connection + break + + assert items_connection is not None, "No connection was found for items endpoint" + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 0 + + response = test_client.get('/items/asd') + assert response.status_code == 200 + assert response.json() == { + "table_name": "items", + "record_id": "asd", + } + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 1 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py new file mode 100644 index 000000000..6c0c2132f --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py @@ -0,0 +1,129 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013b_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_records_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_records(self, table_name: str) -> List[dict]: + self.get_records_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return [] + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_records(self, table_name) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 3 + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + assert connection.get_records_count == 0 + assert connection.get_record_count == 0 + + response = test_client.get('/users') + assert response.status_code == 200 + assert response.json() == [] + + users_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1: + users_connection = connection + break + + assert users_connection is not None, "No connection was found for users endpoint" + + response = test_client.get('/groups') + assert response.status_code == 200 + assert response.json() == [] + + groups_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 1 and connection is not users_connection: + groups_connection = connection + break + + assert groups_connection is not None, "No connection was found for groups endpoint" + assert groups_connection.get_records_count == 1 + + items_connection = None + for connection in database_connection_mocks: + if connection.get_records_count == 0: + items_connection = connection + break + + assert items_connection is not None, "No connection was found for items endpoint" + + response = test_client.get('/items') + assert response.status_code == 200 + assert response.json() == [] + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 0 + + response = test_client.get('/items/asd') + assert response.status_code == 200 + assert response.json() == { + "table_name": "items", + "record_id": "asd", + } + + assert items_connection.get_records_count == 1 + assert items_connection.get_record_count == 1 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 0 + + for connection in database_connection_mocks: + assert connection.enter_count == 1 + assert connection.exit_count == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c.py b/tests/test_tutorial/test_dependencies/test_tutorial013c.py new file mode 100644 index 000000000..09390cd99 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c.py @@ -0,0 +1,82 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013c import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self, url: str): + self.url = url + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(cls, url): + mock = MockDatabaseConnection(url) + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.url == "sqlite:///database.db" + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py new file mode 100644 index 000000000..03f01145a --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py @@ -0,0 +1,82 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013c_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self, url: str): + self.url = url + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(cls, url): + mock = MockDatabaseConnection(url) + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.url == "sqlite:///database.db" + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d.py b/tests/test_tutorial/test_dependencies/test_tutorial013d.py new file mode 100644 index 000000000..dccf6a006 --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d.py @@ -0,0 +1,80 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013d import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1 diff --git a/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py new file mode 100644 index 000000000..4f27fccca --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py @@ -0,0 +1,80 @@ +from typing import List + +import pytest +from starlette.testclient import TestClient +from typing_extensions import Self + +from docs_src.dependencies.tutorial013d_an_py39 import MyDatabaseConnection, app + + +class MockDatabaseConnection: + def __init__(self): + self.enter_count = 0 + self.exit_count = 0 + self.get_record_count = 0 + + async def __aenter__(self) -> Self: + self.enter_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aenter__(self) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_count += 1 + # Called for the sake of coverage. + return await MyDatabaseConnection.__aexit__(self, exc_type, exc_val, exc_tb) + + async def get_record(self, table_name: str, record_id: str) -> dict: + self.get_record_count += 1 + # Called for the sake of coverage. + await MyDatabaseConnection.get_record(self, table_name, record_id) + return { + "table_name": table_name, + "record_id": record_id, + } + + + +@pytest.fixture +def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]: + connections = [] + def _get_new_connection_mock(*args, **kwargs): + mock = MockDatabaseConnection() + connections.append(mock) + + return mock + + monkeypatch.setattr( + MyDatabaseConnection, + "__new__", + _get_new_connection_mock + ) + return connections + + +def test_dependency_usage(database_connection_mocks): + assert len(database_connection_mocks) == 0 + + with TestClient(app) as test_client: + assert len(database_connection_mocks) == 1 + [database_connection_mock] = database_connection_mocks + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 0 + + response = test_client.get('/users/user') + assert response.status_code == 200 + assert response.json() == { + "table_name": "users", + "record_id": "user", + } + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 0 + assert database_connection_mock.get_record_count == 1 + + assert database_connection_mock.enter_count == 1 + assert database_connection_mock.exit_count == 1 + assert database_connection_mock.get_record_count == 1 + + assert len(database_connection_mocks) == 1