Browse Source

Merge 463419ef8d into 8af92a6139

pull/9555/merge
Matthew Martin 2 days ago
committed by GitHub
parent
commit
4d607c9b83
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 20
      fastapi/dependencies/utils.py
  2. 77
      tests/test_dependency_wrapped.py

20
fastapi/dependencies/utils.py

@ -551,11 +551,15 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
*,
call: Callable[..., Any],
unwrapped_call: Callable[..., Any],
stack: AsyncExitStack,
sub_values: Dict[str, Any],
) -> Any:
if is_gen_callable(call):
if is_gen_callable(unwrapped_call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call):
elif is_async_gen_callable(unwrapped_call):
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
@ -628,13 +632,17 @@ async def solve_dependencies(
if solved_result.errors:
errors.extend(solved_result.errors)
continue
unwrapped_call = inspect.unwrap(call)
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
elif is_gen_callable(unwrapped_call) or is_async_gen_callable(unwrapped_call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
call=call,
unwrapped_call=unwrapped_call,
stack=async_exit_stack,
sub_values=solved_result.values,
)
elif is_coroutine_callable(call):
elif is_coroutine_callable(unwrapped_call):
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)

77
tests/test_dependency_wrapped.py

@ -0,0 +1,77 @@
from functools import wraps
from typing import AsyncGenerator, Generator
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
def noop_wrap(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
app = FastAPI()
@noop_wrap
def wrapped_dependency() -> bool:
return True
@noop_wrap
def wrapped_gen_dependency() -> Generator[bool, None, None]:
yield True
@noop_wrap
async def async_wrapped_dependency() -> bool:
return True
@noop_wrap
async def async_wrapped_gen_dependency() -> AsyncGenerator[bool, None]:
yield True
@app.get("/wrapped-dependency/")
async def get_wrapped_dependency(value: bool = Depends(wrapped_dependency)):
return value
@app.get("/wrapped-gen-dependency/")
async def get_wrapped_gen_dependency(value: bool = Depends(wrapped_gen_dependency)):
return value
@app.get("/async-wrapped-dependency/")
async def get_async_wrapped_dependency(value: bool = Depends(async_wrapped_dependency)):
return value
@app.get("/async-wrapped-gen-dependency/")
async def get_async_wrapped_gen_dependency(
value: bool = Depends(async_wrapped_gen_dependency),
):
return value
client = TestClient(app)
@pytest.mark.parametrize(
"route",
[
"/wrapped-dependency",
"/wrapped-gen-dependency",
"/async-wrapped-dependency",
"/async-wrapped-gen-dependency",
],
)
def test_class_dependency(route):
response = client.get(route)
assert response.status_code == 200, response.text
assert response.json() is True
Loading…
Cancel
Save