diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..77005df09 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -551,11 +551,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] + *, + call: Callable[..., Any], + unwrapped_call: Callable[..., Any], + stack: AsyncExitStack, + sub_values: Dict[str, Any], ) -> Any: - if is_gen_callable(call): + if is_gen_callable(unwrapped_call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif is_async_gen_callable(call): + elif is_async_gen_callable(unwrapped_call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -628,13 +632,17 @@ async def solve_dependencies( if solved_result.errors: errors.extend(solved_result.errors) continue + unwrapped_call = inspect.unwrap(call) if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): + elif is_gen_callable(unwrapped_call) or is_async_gen_callable(unwrapped_call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + call=call, + unwrapped_call=unwrapped_call, + stack=async_exit_stack, + sub_values=solved_result.values, ) - elif is_coroutine_callable(call): + elif is_coroutine_callable(unwrapped_call): solved = await call(**solved_result.values) else: solved = await run_in_threadpool(call, **solved_result.values) diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py new file mode 100644 index 000000000..f581ccba4 --- /dev/null +++ b/tests/test_dependency_wrapped.py @@ -0,0 +1,77 @@ +from functools import wraps +from typing import AsyncGenerator, Generator + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + + +def noop_wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +app = FastAPI() + + +@noop_wrap +def wrapped_dependency() -> bool: + return True + + +@noop_wrap +def wrapped_gen_dependency() -> Generator[bool, None, None]: + yield True + + +@noop_wrap +async def async_wrapped_dependency() -> bool: + return True + + +@noop_wrap +async def async_wrapped_gen_dependency() -> AsyncGenerator[bool, None]: + yield True + + +@app.get("/wrapped-dependency/") +async def get_wrapped_dependency(value: bool = Depends(wrapped_dependency)): + return value + + +@app.get("/wrapped-gen-dependency/") +async def get_wrapped_gen_dependency(value: bool = Depends(wrapped_gen_dependency)): + return value + + +@app.get("/async-wrapped-dependency/") +async def get_async_wrapped_dependency(value: bool = Depends(async_wrapped_dependency)): + return value + + +@app.get("/async-wrapped-gen-dependency/") +async def get_async_wrapped_gen_dependency( + value: bool = Depends(async_wrapped_gen_dependency), +): + return value + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "route", + [ + "/wrapped-dependency", + "/wrapped-gen-dependency", + "/async-wrapped-dependency", + "/async-wrapped-gen-dependency", + ], +) +def test_class_dependency(route): + response = client.get(route) + assert response.status_code == 200, response.text + assert response.json() is True