From affb3ca01707cf32181211c50b91a4120c8d14dc Mon Sep 17 00:00:00 2001 From: Alexander Pushkov Date: Tue, 5 Dec 2023 12:51:59 +0300 Subject: [PATCH] Add basic tests for using context managers as dependencies --- tests/test_dependency_contextmanager.py | 43 +++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py index 03ef56c4d..ce0f56474 100644 --- a/tests/test_dependency_contextmanager.py +++ b/tests/test_dependency_contextmanager.py @@ -1,4 +1,5 @@ from typing import Dict +from contextlib import asynccontextmanager, contextmanager import pytest from fastapi import BackgroundTasks, Depends, FastAPI @@ -8,6 +9,8 @@ app = FastAPI() state = { "/async": "asyncgen not started", "/sync": "generator not started", + "/async_ctxmgr": "asyncgen_ctxmgr not started", + "/sync_ctxmgr": "generator_ctxmgr not started", "/async_raise": "asyncgen raise not started", "/sync_raise": "generator raise not started", "context_a": "not started a", @@ -47,6 +50,20 @@ def generator_state(state: Dict[str, str] = Depends(get_state)): state["/sync"] = "generator completed" +@asynccontextmanager +async def asyncgen_state_ctxmgr(state: Dict[str, str] = Depends(get_state)): + state["/async_ctxmgr"] = "asyncgen_ctxmgr started" + yield state["/async_ctxmgr"] + state["/async_ctxmgr"] = "asyncgen_ctxmgr completed" + + +@contextmanager +def generator_state_ctxmgr(state: Dict[str, str] = Depends(get_state)): + state["/sync_ctxmgr"] = "generator_ctxmgr started" + yield state["/sync_ctxmgr"] + state["/sync_ctxmgr"] = "generator_ctxmgr completed" + + async def asyncgen_state_try(state: Dict[str, str] = Depends(get_state)): state["/async_raise"] = "asyncgen raise started" try: @@ -93,6 +110,16 @@ async def get_sync(state: str = Depends(generator_state)): return state +@app.get("/async_ctxmgr") +async def get_async(state: str = Depends(asyncgen_state_ctxmgr)): + return state + + +@app.get("/sync_ctxmgr") +async def get_sync(state: str = Depends(generator_state_ctxmgr)): + return state + + @app.get("/async_raise") async def get_async_raise(state: str = Depends(asyncgen_state_try)): assert state == "asyncgen raise started" @@ -219,6 +246,22 @@ def test_sync_state(): assert state["/sync"] == "generator completed" +def test_async_ctxmgr_state(): + assert state["/async_ctxmgr"] == "asyncgen_ctxmgr not started" + response = client.get("/async_ctxmgr") + assert response.status_code == 200, response.text + assert response.json() == "asyncgen_ctxmgr started" + assert state["/async_ctxmgr"] == "asyncgen_ctxmgr completed" + + +def test_sync_ctxmgr_state(): + assert state["/sync_ctxmgr"] == "generator_ctxmgr not started" + response = client.get("/sync_ctxmgr") + assert response.status_code == 200, response.text + assert response.json() == "generator_ctxmgr started" + assert state["/sync_ctxmgr"] == "generator_ctxmgr completed" + + def test_async_raise_other(): assert state["/async_raise"] == "asyncgen raise not started" with pytest.raises(OtherDependencyError):