From 0db992850395300d425d7baa48ed6891c41f70d9 Mon Sep 17 00:00:00 2001 From: Matthew Martin Date: Sat, 20 May 2023 12:55:37 -0500 Subject: [PATCH] Handle wrapped dependencies --- fastapi/dependencies/utils.py | 20 ++++++--- tests/test_dependency_wrapped.py | 77 ++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 tests/test_dependency_wrapped.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index f131001ce..65d3ecdd5 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -534,11 +534,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] + *, + call: Callable[..., Any], + unwrapped_call: Callable[..., Any], + stack: AsyncExitStack, + sub_values: Dict[str, Any], ) -> Any: - if is_gen_callable(call): + if is_gen_callable(unwrapped_call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif is_async_gen_callable(call): + elif is_async_gen_callable(unwrapped_call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -610,15 +614,19 @@ async def solve_dependencies( if sub_errors: errors.extend(sub_errors) continue + unwrapped_call = inspect.unwrap(call) if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): + elif is_gen_callable(unwrapped_call) or is_async_gen_callable(unwrapped_call): stack = request.scope.get("fastapi_astack") assert isinstance(stack, AsyncExitStack) solved = await solve_generator( - call=call, stack=stack, sub_values=sub_values + call=call, + unwrapped_call=unwrapped_call, + stack=stack, + sub_values=sub_values, ) - elif is_coroutine_callable(call): + elif is_coroutine_callable(unwrapped_call): solved = await call(**sub_values) else: solved = await run_in_threadpool(call, **sub_values) diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py new file mode 100644 index 000000000..f581ccba4 --- /dev/null +++ b/tests/test_dependency_wrapped.py @@ -0,0 +1,77 @@ +from functools import wraps +from typing import AsyncGenerator, Generator + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + + +def noop_wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +app = FastAPI() + + +@noop_wrap +def wrapped_dependency() -> bool: + return True + + +@noop_wrap +def wrapped_gen_dependency() -> Generator[bool, None, None]: + yield True + + +@noop_wrap +async def async_wrapped_dependency() -> bool: + return True + + +@noop_wrap +async def async_wrapped_gen_dependency() -> AsyncGenerator[bool, None]: + yield True + + +@app.get("/wrapped-dependency/") +async def get_wrapped_dependency(value: bool = Depends(wrapped_dependency)): + return value + + +@app.get("/wrapped-gen-dependency/") +async def get_wrapped_gen_dependency(value: bool = Depends(wrapped_gen_dependency)): + return value + + +@app.get("/async-wrapped-dependency/") +async def get_async_wrapped_dependency(value: bool = Depends(async_wrapped_dependency)): + return value + + +@app.get("/async-wrapped-gen-dependency/") +async def get_async_wrapped_gen_dependency( + value: bool = Depends(async_wrapped_gen_dependency), +): + return value + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "route", + [ + "/wrapped-dependency", + "/wrapped-gen-dependency", + "/async-wrapped-dependency", + "/async-wrapped-gen-dependency", + ], +) +def test_class_dependency(route): + response = client.get(route) + assert response.status_code == 200, response.text + assert response.json() is True