diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1a660f5d3..3ff7d3356 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -274,7 +274,7 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call): + if is_gen_callable(call) or is_async_gen_callable(call): check_dependency_contextmanagers() dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache) for param_name, param in signature_params.items(): @@ -412,19 +412,41 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: def is_coroutine_callable(call: Callable) -> bool: if inspect.isroutine(call): - return asyncio.iscoroutinefunction(call) + return inspect.iscoroutinefunction(call) if inspect.isclass(call): return False call = getattr(call, "__call__", None) - return asyncio.iscoroutinefunction(call) + return inspect.iscoroutinefunction(call) + + +def is_async_gen_callable(call: Callable) -> bool: + if inspect.isasyncgenfunction(call): + return True + call = getattr(call, "__call__", None) + return inspect.isasyncgenfunction(call) + + +def is_gen_callable(call: Callable) -> bool: + if inspect.isgeneratorfunction(call): + return True + call = getattr(call, "__call__", None) + return inspect.isgeneratorfunction(call) async def solve_generator( *, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any] ) -> Any: - if inspect.isgeneratorfunction(call): + if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif inspect.isasyncgenfunction(call): + elif is_async_gen_callable(call): + if not inspect.isasyncgenfunction(call): + # asynccontextmanager from the async_generator backfill pre python3.7 + # does not support callables that are not functions or methods. + # See https://github.com/python-trio/async_generator/issues/32 + # + # Expand the callable class into its __call__ method before decorating it. + # This approach will work on newer python versions as well. + call = getattr(call, "__call__", None) cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -505,7 +527,7 @@ async def solve_dependencies( continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] - elif inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call): + elif is_gen_callable(call) or is_async_gen_callable(call): stack = request.scope.get("fastapi_astack") if stack is None: raise RuntimeError( diff --git a/tests/test_dependency_class.py b/tests/test_dependency_class.py index ba2e3cfcf..bfe777f52 100644 --- a/tests/test_dependency_class.py +++ b/tests/test_dependency_class.py @@ -1,3 +1,5 @@ +from typing import AsyncGenerator, Generator + import pytest from fastapi import Depends, FastAPI from fastapi.testclient import TestClient @@ -10,11 +12,21 @@ class CallableDependency: return value +class CallableGenDependency: + def __call__(self, value: str) -> Generator[str, None, None]: + yield value + + class AsyncCallableDependency: async def __call__(self, value: str) -> str: return value +class AsyncCallableGenDependency: + async def __call__(self, value: str) -> AsyncGenerator[str, None]: + yield value + + class MethodsDependency: def synchronous(self, value: str) -> str: return value @@ -22,9 +34,17 @@ class MethodsDependency: async def asynchronous(self, value: str) -> str: return value + def synchronous_gen(self, value: str) -> Generator[str, None, None]: + yield value + + async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]: + yield value + callable_dependency = CallableDependency() +callable_gen_dependency = CallableGenDependency() async_callable_dependency = AsyncCallableDependency() +async_callable_gen_dependency = AsyncCallableGenDependency() methods_dependency = MethodsDependency() @@ -33,11 +53,23 @@ async def get_callable_dependency(value: str = Depends(callable_dependency)): return value +@app.get("/callable-gen-dependency") +async def get_callable_gen_dependency(value: str = Depends(callable_gen_dependency)): + return value + + @app.get("/async-callable-dependency") async def get_callable_dependency(value: str = Depends(async_callable_dependency)): return value +@app.get("/async-callable-gen-dependency") +async def get_callable_gen_dependency( + value: str = Depends(async_callable_gen_dependency), +): + return value + + @app.get("/synchronous-method-dependency") async def get_synchronous_method_dependency( value: str = Depends(methods_dependency.synchronous), @@ -45,6 +77,13 @@ async def get_synchronous_method_dependency( return value +@app.get("/synchronous-method-gen-dependency") +async def get_synchronous_method_gen_dependency( + value: str = Depends(methods_dependency.synchronous_gen), +): + return value + + @app.get("/asynchronous-method-dependency") async def get_asynchronous_method_dependency( value: str = Depends(methods_dependency.asynchronous), @@ -52,6 +91,13 @@ async def get_asynchronous_method_dependency( return value +@app.get("/asynchronous-method-gen-dependency") +async def get_asynchronous_method_gen_dependency( + value: str = Depends(methods_dependency.asynchronous_gen), +): + return value + + client = TestClient(app) @@ -59,9 +105,13 @@ client = TestClient(app) "route,value", [ ("/callable-dependency", "callable-dependency"), + ("/callable-gen-dependency", "callable-gen-dependency"), ("/async-callable-dependency", "async-callable-dependency"), + ("/async-callable-gen-dependency", "async-callable-gen-dependency"), ("/synchronous-method-dependency", "synchronous-method-dependency"), + ("/synchronous-method-gen-dependency", "synchronous-method-gen-dependency"), ("/asynchronous-method-dependency", "asynchronous-method-dependency"), + ("/asynchronous-method-gen-dependency", "asynchronous-method-gen-dependency"), ], ) def test_class_dependency(route, value):