diff --git a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
index 2b97ba39e..a6cf2f305 100644
--- a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
+++ b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
@@ -23,6 +23,8 @@ In fact, FastAPI uses those two decorators internally.
///
+And if you have a context manager already, you can use it as a dependency, too.
+
## A database dependency with `yield`
For example, you could use this to create a database session and close it after finishing.
@@ -240,36 +242,37 @@ When the `with` block finishes, it makes sure to close the file, even if there w
When you create a dependency with `yield`, **FastAPI** will internally create a context manager for it, and combine it with some other related tools.
-### Using context managers in dependencies with `yield`
+### Context Managers as dependencies
-/// warning
+You don’t have to create a Context Manager to use FastAPI dependencies.
-This is, more or less, an "advanced" idea.
+But sometimes you might want to use a dependency both inside and outside FastAPI. For example, you might want to connect to a database in a DB migration script. You can create a Context Manager manually then:
-If you are just starting with **FastAPI** you might want to skip it for now.
+{* ../../docs_src/dependencies/tutorial010_ctx.py *}
-///
+This way you can use it with `Depends()` in your app, and as a regular Context Manager everywhere else:
-In Python, you can create Context Managers by creating a class with two methods: `__enter__()` and `__exit__()`.
+{* ../../docs_src/dependencies/tutorial010_ctx_usage.py *}
-You can also use them inside of **FastAPI** dependencies with `yield` by using
-`with` or `async with` statements inside of the dependency function:
-{* ../../docs_src/dependencies/tutorial010.py hl[1:9,13] *}
+### Defining Context Managers as classes
-/// tip
+/// warning
-Another way to create a context manager is with:
+This is, more or less, an "advanced" idea.
-* `@contextlib.contextmanager` or
-* `@contextlib.asynccontextmanager`
+If you are just starting with **FastAPI** you might want to skip it for now.
+
+///
+
+You can also create Context Managers by creating a class with two methods: `__enter__()` and `__exit__()`.
-using them to decorate a function with a single `yield`.
+Like the Context Managers created with `contextlib`, you can use them inside of **FastAPI** dependencies directly:
-That's what **FastAPI** uses internally for dependencies with `yield`.
+{* ../../docs_src/dependencies/tutorial010.py hl[1:9,13] *}
-But you don't have to use the decorators for FastAPI dependencies (and you shouldn't).
+/// note | Technical Details
-FastAPI will do it for you internally.
+Internally, FastAPI calls the dependency function and then checks if the result has either an `__enter__()` or an `__aenter__()` method. If that’s the case, it will treat it as a context manager.
///
diff --git a/docs_src/dependencies/tutorial010.py b/docs_src/dependencies/tutorial010.py
index c27f1b170..3b0f024bc 100644
--- a/docs_src/dependencies/tutorial010.py
+++ b/docs_src/dependencies/tutorial010.py
@@ -9,6 +9,5 @@ class MySuperContextManager:
self.db.close()
-async def get_db():
- with MySuperContextManager() as db:
- yield db
+@app.get("/")
+async def get_root(db: Annotated[DBSession, Depends(MySuperContextManager)]): ...
diff --git a/docs_src/dependencies/tutorial010_ctx.py b/docs_src/dependencies/tutorial010_ctx.py
new file mode 100644
index 000000000..1fd180c8f
--- /dev/null
+++ b/docs_src/dependencies/tutorial010_ctx.py
@@ -0,0 +1,10 @@
+from contextlib import asynccontextmanager
+
+
+@asynccontextmanager
+async def get_db():
+ db = DBSession()
+ try:
+ yield db
+ finally:
+ db.close()
diff --git a/docs_src/dependencies/tutorial010_ctx_usage.py b/docs_src/dependencies/tutorial010_ctx_usage.py
new file mode 100644
index 000000000..11a3ccf2e
--- /dev/null
+++ b/docs_src/dependencies/tutorial010_ctx_usage.py
@@ -0,0 +1,3 @@
+async def migrate():
+ with get_db() as db:
+ ...
diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py
index 081b63a8b..1cbedcce9 100644
--- a/fastapi/dependencies/utils.py
+++ b/fastapi/dependencies/utils.py
@@ -637,7 +637,14 @@ async def solve_dependencies(
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
else:
- solved = await run_in_threadpool(call, **solved_result.values)
+ called = await run_in_threadpool(call, **solved_result.values)
+ if hasattr(called, "__aenter__"):
+ solved = await async_exit_stack.enter_async_context(called)
+ elif hasattr(called, "__enter__"):
+ cm = contextmanager_in_threadpool(called)
+ solved = await async_exit_stack.enter_async_context(cm)
+ else:
+ solved = called
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
diff --git a/pyproject.toml b/pyproject.toml
index 7709451ff..5d720e05d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -238,6 +238,8 @@ ignore = [
"docs_src/custom_request_and_route/tutorial002.py" = ["B904"]
"docs_src/dependencies/tutorial008_an.py" = ["F821"]
"docs_src/dependencies/tutorial008_an_py39.py" = ["F821"]
+"docs_src/dependencies/tutorial010_ctx.py" = ["F821"]
+"docs_src/dependencies/tutorial010_ctx_usage.py" = ["F821", "F841"]
"docs_src/query_params_str_validations/tutorial012_an.py" = ["B006"]
"docs_src/query_params_str_validations/tutorial012_an_py39.py" = ["B006"]
"docs_src/query_params_str_validations/tutorial013_an.py" = ["B006"]
diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py
index 039c423b9..8849e8cf3 100644
--- a/tests/test_dependency_contextmanager.py
+++ b/tests/test_dependency_contextmanager.py
@@ -1,4 +1,5 @@
import json
+from contextlib import asynccontextmanager, contextmanager
from typing import Dict
import pytest
@@ -10,6 +11,8 @@ app = FastAPI()
state = {
"/async": "asyncgen not started",
"/sync": "generator not started",
+ "/async_ctxmgr": "asyncgen_ctxmgr not started",
+ "/sync_ctxmgr": "generator_ctxmgr not started",
"/async_raise": "asyncgen raise not started",
"/sync_raise": "generator raise not started",
"context_a": "not started a",
@@ -49,6 +52,20 @@ def generator_state(state: Dict[str, str] = Depends(get_state)):
state["/sync"] = "generator completed"
+@asynccontextmanager
+async def asyncgen_state_ctxmgr(state: Dict[str, str] = Depends(get_state)):
+ state["/async_ctxmgr"] = "asyncgen_ctxmgr started"
+ yield state["/async_ctxmgr"]
+ state["/async_ctxmgr"] = "asyncgen_ctxmgr completed"
+
+
+@contextmanager
+def generator_state_ctxmgr(state: Dict[str, str] = Depends(get_state)):
+ state["/sync_ctxmgr"] = "generator_ctxmgr started"
+ yield state["/sync_ctxmgr"]
+ state["/sync_ctxmgr"] = "generator_ctxmgr completed"
+
+
async def asyncgen_state_try(state: Dict[str, str] = Depends(get_state)):
state["/async_raise"] = "asyncgen raise started"
try:
@@ -97,6 +114,16 @@ async def get_sync(state: str = Depends(generator_state)):
return state
+@app.get("/async_ctxmgr")
+async def get_async_ctxmgr(state: str = Depends(asyncgen_state_ctxmgr)):
+ return state
+
+
+@app.get("/sync_ctxmgr")
+async def get_sync_ctxmgr(state: str = Depends(generator_state_ctxmgr)):
+ return state
+
+
@app.get("/async_raise")
async def get_async_raise(state: str = Depends(asyncgen_state_try)):
assert state == "asyncgen raise started"
@@ -230,6 +257,22 @@ def test_sync_state():
assert state["/sync"] == "generator completed"
+def test_async_ctxmgr_state():
+ assert state["/async_ctxmgr"] == "asyncgen_ctxmgr not started"
+ response = client.get("/async_ctxmgr")
+ assert response.status_code == 200, response.text
+ assert response.json() == "asyncgen_ctxmgr started"
+ assert state["/async_ctxmgr"] == "asyncgen_ctxmgr completed"
+
+
+def test_sync_ctxmgr_state():
+ assert state["/sync_ctxmgr"] == "generator_ctxmgr not started"
+ response = client.get("/sync_ctxmgr")
+ assert response.status_code == 200, response.text
+ assert response.json() == "generator_ctxmgr started"
+ assert state["/sync_ctxmgr"] == "generator_ctxmgr completed"
+
+
def test_async_raise_other():
assert state["/async_raise"] == "asyncgen raise not started"
with pytest.raises(OtherDependencyError):