diff --git a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md index 2b97ba39e..a6cf2f305 100644 --- a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md +++ b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md @@ -23,6 +23,8 @@ In fact, FastAPI uses those two decorators internally. /// +And if you have a context manager already, you can use it as a dependency, too. + ## A database dependency with `yield` For example, you could use this to create a database session and close it after finishing. @@ -240,36 +242,37 @@ When the `with` block finishes, it makes sure to close the file, even if there w When you create a dependency with `yield`, **FastAPI** will internally create a context manager for it, and combine it with some other related tools. -### Using context managers in dependencies with `yield` +### Context Managers as dependencies -/// warning +You don’t have to create a Context Manager to use FastAPI dependencies. -This is, more or less, an "advanced" idea. +But sometimes you might want to use a dependency both inside and outside FastAPI. For example, you might want to connect to a database in a DB migration script. You can create a Context Manager manually then: -If you are just starting with **FastAPI** you might want to skip it for now. +{* ../../docs_src/dependencies/tutorial010_ctx.py *} -/// +This way you can use it with `Depends()` in your app, and as a regular Context Manager everywhere else: -In Python, you can create Context Managers by creating a class with two methods: `__enter__()` and `__exit__()`. +{* ../../docs_src/dependencies/tutorial010_ctx_usage.py *} -You can also use them inside of **FastAPI** dependencies with `yield` by using -`with` or `async with` statements inside of the dependency function: -{* ../../docs_src/dependencies/tutorial010.py hl[1:9,13] *} +### Defining Context Managers as classes -/// tip +/// warning -Another way to create a context manager is with: +This is, more or less, an "advanced" idea. -* `@contextlib.contextmanager` or -* `@contextlib.asynccontextmanager` +If you are just starting with **FastAPI** you might want to skip it for now. + +/// + +You can also create Context Managers by creating a class with two methods: `__enter__()` and `__exit__()`. -using them to decorate a function with a single `yield`. +Like the Context Managers created with `contextlib`, you can use them inside of **FastAPI** dependencies directly: -That's what **FastAPI** uses internally for dependencies with `yield`. +{* ../../docs_src/dependencies/tutorial010.py hl[1:9,13] *} -But you don't have to use the decorators for FastAPI dependencies (and you shouldn't). +/// note | Technical Details -FastAPI will do it for you internally. +Internally, FastAPI calls the dependency function and then checks if the result has either an `__enter__()` or an `__aenter__()` method. If that’s the case, it will treat it as a context manager. /// diff --git a/docs_src/dependencies/tutorial010.py b/docs_src/dependencies/tutorial010.py index c27f1b170..3b0f024bc 100644 --- a/docs_src/dependencies/tutorial010.py +++ b/docs_src/dependencies/tutorial010.py @@ -9,6 +9,5 @@ class MySuperContextManager: self.db.close() -async def get_db(): - with MySuperContextManager() as db: - yield db +@app.get("/") +async def get_root(db: Annotated[DBSession, Depends(MySuperContextManager)]): ... diff --git a/docs_src/dependencies/tutorial010_ctx.py b/docs_src/dependencies/tutorial010_ctx.py new file mode 100644 index 000000000..1fd180c8f --- /dev/null +++ b/docs_src/dependencies/tutorial010_ctx.py @@ -0,0 +1,10 @@ +from contextlib import asynccontextmanager + + +@asynccontextmanager +async def get_db(): + db = DBSession() + try: + yield db + finally: + db.close() diff --git a/docs_src/dependencies/tutorial010_ctx_usage.py b/docs_src/dependencies/tutorial010_ctx_usage.py new file mode 100644 index 000000000..11a3ccf2e --- /dev/null +++ b/docs_src/dependencies/tutorial010_ctx_usage.py @@ -0,0 +1,3 @@ +async def migrate(): + with get_db() as db: + ... diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..1cbedcce9 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -637,7 +637,14 @@ async def solve_dependencies( elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, **solved_result.values) + called = await run_in_threadpool(call, **solved_result.values) + if hasattr(called, "__aenter__"): + solved = await async_exit_stack.enter_async_context(called) + elif hasattr(called, "__enter__"): + cm = contextmanager_in_threadpool(called) + solved = await async_exit_stack.enter_async_context(cm) + else: + solved = called if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/pyproject.toml b/pyproject.toml index 7709451ff..5d720e05d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,6 +238,8 @@ ignore = [ "docs_src/custom_request_and_route/tutorial002.py" = ["B904"] "docs_src/dependencies/tutorial008_an.py" = ["F821"] "docs_src/dependencies/tutorial008_an_py39.py" = ["F821"] +"docs_src/dependencies/tutorial010_ctx.py" = ["F821"] +"docs_src/dependencies/tutorial010_ctx_usage.py" = ["F821", "F841"] "docs_src/query_params_str_validations/tutorial012_an.py" = ["B006"] "docs_src/query_params_str_validations/tutorial012_an_py39.py" = ["B006"] "docs_src/query_params_str_validations/tutorial013_an.py" = ["B006"] diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py index 039c423b9..8849e8cf3 100644 --- a/tests/test_dependency_contextmanager.py +++ b/tests/test_dependency_contextmanager.py @@ -1,4 +1,5 @@ import json +from contextlib import asynccontextmanager, contextmanager from typing import Dict import pytest @@ -10,6 +11,8 @@ app = FastAPI() state = { "/async": "asyncgen not started", "/sync": "generator not started", + "/async_ctxmgr": "asyncgen_ctxmgr not started", + "/sync_ctxmgr": "generator_ctxmgr not started", "/async_raise": "asyncgen raise not started", "/sync_raise": "generator raise not started", "context_a": "not started a", @@ -49,6 +52,20 @@ def generator_state(state: Dict[str, str] = Depends(get_state)): state["/sync"] = "generator completed" +@asynccontextmanager +async def asyncgen_state_ctxmgr(state: Dict[str, str] = Depends(get_state)): + state["/async_ctxmgr"] = "asyncgen_ctxmgr started" + yield state["/async_ctxmgr"] + state["/async_ctxmgr"] = "asyncgen_ctxmgr completed" + + +@contextmanager +def generator_state_ctxmgr(state: Dict[str, str] = Depends(get_state)): + state["/sync_ctxmgr"] = "generator_ctxmgr started" + yield state["/sync_ctxmgr"] + state["/sync_ctxmgr"] = "generator_ctxmgr completed" + + async def asyncgen_state_try(state: Dict[str, str] = Depends(get_state)): state["/async_raise"] = "asyncgen raise started" try: @@ -97,6 +114,16 @@ async def get_sync(state: str = Depends(generator_state)): return state +@app.get("/async_ctxmgr") +async def get_async_ctxmgr(state: str = Depends(asyncgen_state_ctxmgr)): + return state + + +@app.get("/sync_ctxmgr") +async def get_sync_ctxmgr(state: str = Depends(generator_state_ctxmgr)): + return state + + @app.get("/async_raise") async def get_async_raise(state: str = Depends(asyncgen_state_try)): assert state == "asyncgen raise started" @@ -230,6 +257,22 @@ def test_sync_state(): assert state["/sync"] == "generator completed" +def test_async_ctxmgr_state(): + assert state["/async_ctxmgr"] == "asyncgen_ctxmgr not started" + response = client.get("/async_ctxmgr") + assert response.status_code == 200, response.text + assert response.json() == "asyncgen_ctxmgr started" + assert state["/async_ctxmgr"] == "asyncgen_ctxmgr completed" + + +def test_sync_ctxmgr_state(): + assert state["/sync_ctxmgr"] == "generator_ctxmgr not started" + response = client.get("/sync_ctxmgr") + assert response.status_code == 200, response.text + assert response.json() == "generator_ctxmgr started" + assert state["/sync_ctxmgr"] == "generator_ctxmgr completed" + + def test_async_raise_other(): assert state["/async_raise"] == "asyncgen raise not started" with pytest.raises(OtherDependencyError):