From f3ddc7bdeb72f6899927a90f173a251764af618b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= <fvoron@kpi-intelligence.com>
Date: Wed, 27 Nov 2019 20:51:30 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Allow=20async=20class=20methods?=
 =?UTF-8?q?=20as=20dependencies=20(#681)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastapi/dependencies/utils.py  |  2 +-
 tests/test_dependency_class.py | 70 ++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_dependency_class.py

diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py
index 54e274762..956fffff4 100644
--- a/fastapi/dependencies/utils.py
+++ b/fastapi/dependencies/utils.py
@@ -412,7 +412,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
 
 
 def is_coroutine_callable(call: Callable) -> bool:
-    if inspect.isfunction(call):
+    if inspect.isroutine(call):
         return asyncio.iscoroutinefunction(call)
     if inspect.isclass(call):
         return False
diff --git a/tests/test_dependency_class.py b/tests/test_dependency_class.py
new file mode 100644
index 000000000..db1f5cc8f
--- /dev/null
+++ b/tests/test_dependency_class.py
@@ -0,0 +1,70 @@
+import pytest
+from fastapi import Depends, FastAPI
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class CallableDependency:
+    def __call__(self, value: str) -> str:
+        return value
+
+
+class AsyncCallableDependency:
+    async def __call__(self, value: str) -> str:
+        return value
+
+
+class MethodsDependency:
+    def synchronous(self, value: str) -> str:
+        return value
+
+    async def asynchronous(self, value: str) -> str:
+        return value
+
+
+callable_dependency = CallableDependency()
+async_callable_dependency = AsyncCallableDependency()
+methods_dependency = MethodsDependency()
+
+
+@app.get("/callable-dependency")
+async def get_callable_dependency(value: str = Depends(callable_dependency)):
+    return value
+
+
+@app.get("/async-callable-dependency")
+async def get_callable_dependency(value: str = Depends(async_callable_dependency)):
+    return value
+
+
+@app.get("/synchronous-method-dependency")
+async def get_synchronous_method_dependency(
+    value: str = Depends(methods_dependency.synchronous),
+):
+    return value
+
+
+@app.get("/asynchronous-method-dependency")
+async def get_asynchronous_method_dependency(
+    value: str = Depends(methods_dependency.asynchronous),
+):
+    return value
+
+
+client = TestClient(app)
+
+
+@pytest.mark.parametrize(
+    "route,value",
+    [
+        ("/callable-dependency", "callable-dependency"),
+        ("/async-callable-dependency", "async-callable-dependency"),
+        ("/synchronous-method-dependency", "synchronous-method-dependency"),
+        ("/asynchronous-method-dependency", "asynchronous-method-dependency"),
+    ],
+)
+def test_class_dependency(route, value):
+    response = client.get(route, params={"value": value})
+    assert response.status_code == 200
+    assert response.json() == value