Browse Source

Merge branch 'feature/lifespan-scoped-dependencies' of github.com:UltimateLobster/fastapi into feature/lifespan-scoped-dependencies

pull/12529/head
Nir Schulman 8 months ago
parent
commit
17143bb3ac
  1. 9
      docs_src/dependencies/tutorial013a.py
  2. 6
      docs_src/dependencies/tutorial013a_an_py39.py
  3. 23
      docs_src/dependencies/tutorial013b.py
  4. 16
      docs_src/dependencies/tutorial013b_an_py39.py
  5. 13
      docs_src/dependencies/tutorial013c.py
  6. 18
      docs_src/dependencies/tutorial013c_an_py39.py
  7. 9
      docs_src/dependencies/tutorial013d.py
  8. 13
      docs_src/dependencies/tutorial013d_an_py39.py
  9. 10
      tests/test_tutorial/test_dependencies/test_tutorial013a.py
  10. 10
      tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py
  11. 29
      tests/test_tutorial/test_dependencies/test_tutorial013b.py
  12. 29
      tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py
  13. 10
      tests/test_tutorial/test_dependencies/test_tutorial013c.py
  14. 10
      tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py
  15. 10
      tests/test_tutorial/test_dependencies/test_tutorial013d.py
  16. 10
      tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py

9
docs_src/dependencies/tutorial013a.py

@ -18,6 +18,7 @@ class MyDatabaseConnection:
async def get_records(self, table_name: str) -> List[dict]:
pass
app = FastAPI()
@ -30,10 +31,14 @@ GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="li
@app.get("/users/")
async def read_users(database_connection: MyDatabaseConnection = GlobalDatabaseConnection):
async def read_users(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_records("users")
@app.get("/items/")
async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection):
async def read_items(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_records("items")

6
docs_src/dependencies/tutorial013a_an_py39.py

@ -8,6 +8,7 @@ class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
async def __aenter__(self) -> Self:
return self
@ -26,7 +27,10 @@ async def get_database_connection():
yield connection
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")]
GlobalDatabaseConnection = Annotated[
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")
]
@app.get("/users/")
async def read_users(database_connection: GlobalDatabaseConnection):

23
docs_src/dependencies/tutorial013b.py

@ -21,6 +21,7 @@ class MyDatabaseConnection:
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
@ -29,26 +30,36 @@ async def get_database_connection():
yield connection
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan")
DedicatedDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)
DedicatedDatabaseConnection = Depends(
get_database_connection, dependency_scope="lifespan", use_cache=False
)
@app.get("/groups/")
async def read_groups(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection):
async def read_groups(
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection,
):
return await database_connection.get_records("groups")
@app.get("/users/")
async def read_users(database_connection: MyDatabaseConnection = DedicatedDatabaseConnection):
async def read_users(
database_connection: MyDatabaseConnection = DedicatedDatabaseConnection,
):
return await database_connection.get_records("users")
@app.get("/items/")
async def read_items(database_connection: MyDatabaseConnection = GlobalDatabaseConnection):
async def read_items(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_records("items")
@app.get("/items/{item_id}")
async def read_item(
item_id: str = Path(),
database_connection: MyDatabaseConnection = GlobalDatabaseConnection
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
):
return await database_connection.get_record("items", item_id)

16
docs_src/dependencies/tutorial013b_an_py39.py

@ -21,6 +21,7 @@ class MyDatabaseConnection:
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
@ -29,13 +30,20 @@ async def get_database_connection():
yield connection
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")]
DedicatedDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan", use_cache=False)]
GlobalDatabaseConnection = Annotated[
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")
]
DedicatedDatabaseConnection = Annotated[
MyDatabaseConnection,
Depends(get_database_connection, dependency_scope="lifespan", use_cache=False),
]
@app.get("/groups/")
async def read_groups(database_connection: DedicatedDatabaseConnection):
return await database_connection.get_records("groups")
@app.get("/users/")
async def read_users(database_connection: DedicatedDatabaseConnection):
return await database_connection.get_records("users")
@ -45,9 +53,9 @@ async def read_users(database_connection: DedicatedDatabaseConnection):
async def read_items(database_connection: GlobalDatabaseConnection):
return await database_connection.get_records("items")
@app.get("/items/{item_id}")
async def read_item(
database_connection: GlobalDatabaseConnection,
item_id: Annotated[str, Path()]
database_connection: GlobalDatabaseConnection, item_id: Annotated[str, Path()]
):
return await database_connection.get_record("items", item_id)

13
docs_src/dependencies/tutorial013c.py

@ -9,6 +9,7 @@ class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
connection_string: str
async def __aenter__(self) -> Self:
@ -20,6 +21,7 @@ class MyDatabaseConnection:
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
@ -28,20 +30,21 @@ async def get_configuration() -> dict:
"database_url": "sqlite:///database.db",
}
GlobalConfiguration = Depends(get_configuration, dependency_scope="lifespan")
async def get_database_connection(
configuration: dict = GlobalConfiguration
):
async def get_database_connection(configuration: dict = GlobalConfiguration):
async with MyDatabaseConnection(configuration["database_url"]) as connection:
yield connection
GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="lifespan")
@app.get("/users/{user_id}")
async def read_user(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
user_id: str = Path()
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
user_id: str = Path(),
):
return await database_connection.get_record("users", user_id)

18
docs_src/dependencies/tutorial013c_an_py39.py

@ -10,6 +10,7 @@ class MyDatabaseConnection:
"""
This is a mock just for example purposes.
"""
connection_string: str
async def __aenter__(self) -> Self:
@ -33,20 +34,25 @@ async def get_configuration() -> dict:
"database_url": "sqlite:///database.db",
}
GlobalConfiguration = Annotated[dict, Depends(get_configuration, dependency_scope="lifespan")]
GlobalConfiguration = Annotated[
dict, Depends(get_configuration, dependency_scope="lifespan")
]
async def get_database_connection(configuration: GlobalConfiguration):
async with MyDatabaseConnection(
configuration["database_url"]) as connection:
async with MyDatabaseConnection(configuration["database_url"]) as connection:
yield connection
GlobalDatabaseConnection = Annotated[get_database_connection, Depends(get_database_connection, dependency_scope="lifespan")]
GlobalDatabaseConnection = Annotated[
get_database_connection,
Depends(get_database_connection, dependency_scope="lifespan"),
]
@app.get("/users/{user_id}")
async def read_user(
database_connection: GlobalDatabaseConnection,
user_id: Annotated[str, Path()]
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()]
):
return await database_connection.get_record("users", user_id)

9
docs_src/dependencies/tutorial013d.py

@ -21,6 +21,7 @@ class MyDatabaseConnection:
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
@ -33,14 +34,12 @@ GlobalDatabaseConnection = Depends(get_database_connection, dependency_scope="li
async def get_user_record(
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
user_id: str = Path()
database_connection: MyDatabaseConnection = GlobalDatabaseConnection,
user_id: str = Path(),
) -> dict:
return await database_connection.get_record("users", user_id)
@app.get("/users/{user_id}")
async def read_user(
user_record: dict = Depends(get_user_record)
):
async def read_user(user_record: dict = Depends(get_user_record)):
return user_record

13
docs_src/dependencies/tutorial013d_an_py39.py

@ -21,6 +21,7 @@ class MyDatabaseConnection:
async def get_record(self, table_name: str, record_id: str) -> dict:
pass
app = FastAPI()
@ -29,17 +30,17 @@ async def get_database_connection():
yield connection
GlobalDatabaseConnection = Annotated[MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")]
GlobalDatabaseConnection = Annotated[
MyDatabaseConnection, Depends(get_database_connection, dependency_scope="lifespan")
]
async def get_user_record(
database_connection: GlobalDatabaseConnection,
user_id: Annotated[str, Path()]
database_connection: GlobalDatabaseConnection, user_id: Annotated[str, Path()]
) -> dict:
return await database_connection.get_record("users", user_id)
@app.get("/users/{user_id}")
async def read_user(
user_record: Annotated[dict, Depends(get_user_record)]
):
async def read_user(user_record: Annotated[dict, Depends(get_user_record)]):
return user_record

10
tests/test_tutorial/test_dependencies/test_tutorial013a.py

@ -34,11 +34,7 @@ class MockDatabaseConnection:
def database_connection_mock(monkeypatch) -> MockDatabaseConnection:
mock = MockDatabaseConnection()
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
lambda *args, **kwargs: mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock)
return mock
@ -50,13 +46,13 @@ def test_dependency_usage(database_connection_mock):
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
response = test_client.get('/users')
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
assert database_connection_mock.get_records_count == 1
response = test_client.get('/items')
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []

10
tests/test_tutorial/test_dependencies/test_tutorial013a_an_py39.py

@ -35,11 +35,7 @@ class MockDatabaseConnection:
def database_connection_mock(monkeypatch) -> MockDatabaseConnection:
mock = MockDatabaseConnection()
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
lambda *args, **kwargs: mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", lambda *args, **kwargs: mock)
return mock
@ -52,13 +48,13 @@ def test_dependency_usage(database_connection_mock):
assert database_connection_mock.enter_count == 1
assert database_connection_mock.exit_count == 0
response = test_client.get('/users')
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
assert database_connection_mock.get_records_count == 1
response = test_client.get('/items')
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []

29
tests/test_tutorial/test_dependencies/test_tutorial013b.py

@ -40,22 +40,17 @@ class MockDatabaseConnection:
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
_get_new_connection_mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@ -70,7 +65,7 @@ def test_dependency_usage(database_connection_mocks):
assert connection.get_records_count == 0
assert connection.get_record_count == 0
response = test_client.get('/users')
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
@ -80,9 +75,11 @@ def test_dependency_usage(database_connection_mocks):
users_connection = connection
break
assert users_connection is not None, "No connection was found for users endpoint"
assert (
users_connection is not None
), "No connection was found for users endpoint"
response = test_client.get('/groups')
response = test_client.get("/groups")
assert response.status_code == 200
assert response.json() == []
@ -92,7 +89,9 @@ def test_dependency_usage(database_connection_mocks):
groups_connection = connection
break
assert groups_connection is not None, "No connection was found for groups endpoint"
assert (
groups_connection is not None
), "No connection was found for groups endpoint"
assert groups_connection.get_records_count == 1
items_connection = None
@ -101,16 +100,18 @@ def test_dependency_usage(database_connection_mocks):
items_connection = connection
break
assert items_connection is not None, "No connection was found for items endpoint"
assert (
items_connection is not None
), "No connection was found for items endpoint"
response = test_client.get('/items')
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []
assert items_connection.get_records_count == 1
assert items_connection.get_record_count == 0
response = test_client.get('/items/asd')
response = test_client.get("/items/asd")
assert response.status_code == 200
assert response.json() == {
"table_name": "items",

29
tests/test_tutorial/test_dependencies/test_tutorial013b_an_py39.py

@ -41,22 +41,17 @@ class MockDatabaseConnection:
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
_get_new_connection_mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@ -72,7 +67,7 @@ def test_dependency_usage(database_connection_mocks):
assert connection.get_records_count == 0
assert connection.get_record_count == 0
response = test_client.get('/users')
response = test_client.get("/users")
assert response.status_code == 200
assert response.json() == []
@ -82,9 +77,11 @@ def test_dependency_usage(database_connection_mocks):
users_connection = connection
break
assert users_connection is not None, "No connection was found for users endpoint"
assert (
users_connection is not None
), "No connection was found for users endpoint"
response = test_client.get('/groups')
response = test_client.get("/groups")
assert response.status_code == 200
assert response.json() == []
@ -94,7 +91,9 @@ def test_dependency_usage(database_connection_mocks):
groups_connection = connection
break
assert groups_connection is not None, "No connection was found for groups endpoint"
assert (
groups_connection is not None
), "No connection was found for groups endpoint"
assert groups_connection.get_records_count == 1
items_connection = None
@ -103,16 +102,18 @@ def test_dependency_usage(database_connection_mocks):
items_connection = connection
break
assert items_connection is not None, "No connection was found for items endpoint"
assert (
items_connection is not None
), "No connection was found for items endpoint"
response = test_client.get('/items')
response = test_client.get("/items")
assert response.status_code == 200
assert response.json() == []
assert items_connection.get_records_count == 1
assert items_connection.get_record_count == 0
response = test_client.get('/items/asd')
response = test_client.get("/items/asd")
assert response.status_code == 200
assert response.json() == {
"table_name": "items",

10
tests/test_tutorial/test_dependencies/test_tutorial013c.py

@ -34,21 +34,17 @@ class MockDatabaseConnection:
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(cls, url):
mock = MockDatabaseConnection(url)
connections.append(mock)
return mock
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
_get_new_connection_mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@ -64,7 +60,7 @@ def test_dependency_usage(database_connection_mocks):
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get('/users/user')
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",

10
tests/test_tutorial/test_dependencies/test_tutorial013c_an_py39.py

@ -35,21 +35,17 @@ class MockDatabaseConnection:
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(cls, url):
mock = MockDatabaseConnection(url)
connections.append(mock)
return mock
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
_get_new_connection_mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@ -66,7 +62,7 @@ def test_dependency_usage(database_connection_mocks):
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get('/users/user')
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",

10
tests/test_tutorial/test_dependencies/test_tutorial013d.py

@ -33,21 +33,17 @@ class MockDatabaseConnection:
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
_get_new_connection_mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@ -62,7 +58,7 @@ def test_dependency_usage(database_connection_mocks):
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get('/users/user')
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",

10
tests/test_tutorial/test_dependencies/test_tutorial013d_an_py39.py

@ -34,21 +34,17 @@ class MockDatabaseConnection:
}
@pytest.fixture
def database_connection_mocks(monkeypatch) -> List[MockDatabaseConnection]:
connections = []
def _get_new_connection_mock(*args, **kwargs):
mock = MockDatabaseConnection()
connections.append(mock)
return mock
monkeypatch.setattr(
MyDatabaseConnection,
"__new__",
_get_new_connection_mock
)
monkeypatch.setattr(MyDatabaseConnection, "__new__", _get_new_connection_mock)
return connections
@ -64,7 +60,7 @@ def test_dependency_usage(database_connection_mocks):
assert database_connection_mock.exit_count == 0
assert database_connection_mock.get_record_count == 0
response = test_client.get('/users/user')
response = test_client.get("/users/user")
assert response.status_code == 200
assert response.json() == {
"table_name": "users",

Loading…
Cancel
Save