Browse Source

Merge 9338618f5f into 6df50d40fe

pull/13607/merge
Aleksandr Sulimov 3 days ago
committed by GitHub
parent
commit
ca430318b0
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 17
      docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
  2. 23
      docs_src/dependencies/tutorial008e.py
  3. 201
      fastapi/concurrency.py
  4. 23
      fastapi/dependencies/utils.py
  5. 3
      fastapi/routing.py
  6. 73
      tests/test_dependency_response.py
  7. 22
      tests/test_tutorial/test_dependencies/test_tutorial008e.py

17
docs/en/docs/tutorial/dependencies/dependencies-with-yield.md

@ -19,7 +19,7 @@ Any function that is valid to use with:
would be valid to use as a **FastAPI** dependency.
In fact, FastAPI uses those two decorators internally.
In fact, FastAPI uses similar decorator internally.
///
@ -129,6 +129,19 @@ You can re-raise the same exception using `raise`:
Now the client will get the same *HTTP 500 Internal Server Error* response, but the server will have our custom `InternalError` in the logs. 😎
## Modifying response in dependencies after `yield`
If you want to read or change response in a dependency with `yield` after `yield`, you can get your endpoint's response
if you assign `yield` to it:
{* ../../docs_src/dependencies/tutorial008e.py hl[13:14] *}
/// warning
You cannot use dependency-injected `Response` to modify response in dependency after `yield`.
///
## Execution of dependencies with `yield`
The sequence of execution is more or less like this diagram. Time flows from top to bottom. And each column is one of the parts interacting or executing code.
@ -266,7 +279,7 @@ Another way to create a context manager is with:
using them to decorate a function with a single `yield`.
That's what **FastAPI** uses internally for dependencies with `yield`.
A similar decorator is what **FastAPI** uses internally for dependencies with `yield`.
But you don't have to use the decorators for FastAPI dependencies (and you shouldn't).

23
docs_src/dependencies/tutorial008e.py

@ -0,0 +1,23 @@
import json
from fastapi import Depends, FastAPI, HTTPException
app = FastAPI()
data = {
"plumbus": {"description": "Freshly pickled plumbus", "owner": "Morty"},
"portal-gun": {"description": "Gun to create portals", "owner": "Rick"},
}
def set_username():
response = yield
response.headers["X-Username"] = json.loads(response.body)["owner"]
@app.get("/items/{item_id}", dependencies=[Depends(set_username)])
def get_item(item_id: str):
if item_id not in data:
raise HTTPException(status_code=404, detail="Item not found")
return data[item_id]

201
fastapi/concurrency.py

@ -1,5 +1,15 @@
from contextlib import asynccontextmanager as asynccontextmanager
from typing import AsyncGenerator, ContextManager, TypeVar
from types import TracebackType
from typing import (
Any,
AsyncGenerator,
ContextManager,
Generator,
Optional,
Type,
TypeVar,
Union,
)
import anyio.to_thread
from anyio import CapacityLimiter
@ -13,9 +23,9 @@ _T = TypeVar("_T")
@asynccontextmanager
async def contextmanager_in_threadpool(
async def contextmanager_in_threadpool( # not used, kept for backwards compatibility
cm: ContextManager[_T],
) -> AsyncGenerator[_T, None]:
) -> AsyncGenerator[_T, None]: # pragma: no cover
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
# has its own internal pool (e.g. a database connection pool)
@ -37,3 +47,188 @@ async def contextmanager_in_threadpool(
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)
class _StopIteration(Exception):
pass
class ContextManagerFromGenerator:
"""Create a context manager from a generator.
It handles both sync and async generators. Generator has to have exactly one yield.
Implementation is based on contextlib's contextmanager.
Additionally, as apposed to contextlib.contextmanager/contextlib.asynccontextmanager,
this context manager allows to call `asend` on underlaying generator and gracefully handle
this scenario in __aexit__.
Instances of this class cannot be reused.
"""
def __init__(self, gen: Union[AsyncGenerator[_T, None], Generator[_T, None, None]]):
self.gen = gen
self._has_started = False
self._has_executed = False
@staticmethod
def _send(gen: Generator[_T, None, None], value: Any) -> Any:
# We can't raise `StopIteration` from within the threadpool executor
# and catch it outside that context, so we coerce them into a different
# exception type.
try:
return gen.send(value)
except StopIteration:
raise _StopIteration from None
@staticmethod
def _throw(gen: Generator[_T, None, None], value: Any) -> Any:
# We can't raise `StopIteration` from within the threadpool executor
# and catch it outside that context, so we coerce them into a different
# exception type.
try:
return gen.throw(value)
except StopIteration:
raise _StopIteration from None
async def __aenter__(self) -> Any:
try:
if isinstance(self.gen, Generator):
result = await run_in_threadpool(self._send, self.gen, None)
else:
result = await self.gen.asend(None)
self._has_started = True
return result
except (_StopIteration, StopAsyncIteration):
raise RuntimeError("generator didn't yield") from None # pragma: no cover
async def asend(self, value: Any) -> None:
if self._has_executed:
raise RuntimeError(
"ContextManagerFromGenerator can only be used once"
) # pragma: no cover
if not self._has_started:
raise RuntimeError(
"ContextManagerFromGenerator has not been entered"
) # pragma: no cover
self._has_executed = True
try:
if isinstance(self.gen, Generator):
await run_in_threadpool(self._send, self.gen, value)
else:
await self.gen.asend(value)
except (_StopIteration, StopAsyncIteration):
return
else: # pragma: no cover
try:
raise RuntimeError("generator didn't stop")
finally:
if isinstance(self.gen, Generator):
self.gen.close()
else:
await self.gen.aclose()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> bool:
if self._has_executed:
if isinstance(self.gen, Generator):
self.gen.close()
return False
await self.gen.aclose()
return False
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
# has its own internal pool (e.g. a database connection pool)
# to avoid this we let __exit__ run without a capacity limit
# since we're creating a new limiter for each call, any non-zero limit
# works (1 is arbitrary)
exit_limiter = CapacityLimiter(1)
if exc_type is None: # pragma: no cover
# usually shouldn't happen, as we call asend
try:
if isinstance(self.gen, Generator):
await anyio.to_thread.run_sync(
self._send,
self.gen,
None,
limiter=exit_limiter,
)
else:
await self.gen.asend(None)
except (_StopIteration, StopAsyncIteration):
return False
else: # pragma: no cover
try:
raise RuntimeError("generator didn't stop")
finally:
if isinstance(self.gen, Generator):
self.gen.close()
else:
await self.gen.aclose()
if exc_value is None: # pragma: no cover
# Need to force instantiation so we can reliably
# tell if we get the same exception back
exc_value = exc_type()
try:
if isinstance(self.gen, Generator):
await anyio.to_thread.run_sync(
self._throw,
self.gen,
exc_value,
limiter=exit_limiter,
)
else:
await self.gen.athrow(exc_value)
except (StopIteration, _StopIteration, StopAsyncIteration) as exc:
# Suppress Stop(Async)Iteration *unless* it's the same exception that
# was passed to throw(). This prevents a Stop(Async)Iteration
# raised inside the "with" statement from being suppressed.
return exc is not exc_value
except RuntimeError as exc: # pragma: no cover
# Don't re-raise the passed in exception. (issue27122)
if exc is exc_value:
exc.__traceback__ = traceback
return False
# Avoid suppressing if a Stop(Async)Iteration exception
# was passed to athrow() and later wrapped into a RuntimeError
# (see PEP 479 for sync generators; async generators also
# have this behavior). But do this only if the exception wrapped
# by the RuntimeError is actually Stop(Async)Iteration (see
# issue29692).
if (
isinstance(
exc_value,
(StopIteration, _StopIteration, StopAsyncIteration),
)
and exc.__cause__ is exc_value
):
exc_value.__traceback__ = traceback
return False
raise
except BaseException as exc:
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
if exc is not exc_value:
raise
exc.__traceback__ = traceback
return False
try: # pragma: no cover
raise RuntimeError("generator didn't stop after athrow()")
finally: # pragma: no cover
if isinstance(self.gen, Generator):
self.gen.close()
else:
await self.gen.aclose()

23
fastapi/dependencies/utils.py

@ -1,5 +1,5 @@
import inspect
from contextlib import AsyncExitStack, contextmanager
from contextlib import AsyncExitStack
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import (
@ -47,10 +47,7 @@ from fastapi._compat import (
value_is_sequence,
)
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.concurrency import ContextManagerFromGenerator
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
@ -552,12 +549,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call):
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
) -> Tuple[Any, Callable[[Response], Coroutine[None, None, None]]]:
cm = ContextManagerFromGenerator(call(**sub_values))
return (await stack.enter_async_context(cm), cm.asend)
@dataclass
@ -567,6 +561,7 @@ class SolvedDependency:
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
generators_callbacks: List[Callable[[Response], Coroutine[None, None, None]]]
async def solve_dependencies(
@ -589,6 +584,7 @@ async def solve_dependencies(
response.status_code = None # type: ignore
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
generators_callbacks: List[Callable[[Response], Coroutine[None, None, None]]] = []
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
@ -625,15 +621,17 @@ async def solve_dependencies(
)
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
generators_callbacks.extend(solved_result.generators_callbacks)
if solved_result.errors:
errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
solved, callback = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
)
generators_callbacks.append(callback)
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
else:
@ -692,6 +690,7 @@ async def solve_dependencies(
background_tasks=background_tasks,
response=response,
dependency_cache=dependency_cache,
generators_callbacks=generators_callbacks,
)

3
fastapi/routing.py

@ -340,6 +340,9 @@ def get_request_handler(
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
# pass response to generator dependencies
for callback in reversed(solved_result.generators_callbacks):
await callback(response)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body

73
tests/test_dependency_response.py

@ -0,0 +1,73 @@
from typing import AsyncGenerator, Generator
from fastapi import Depends, FastAPI, Response
from fastapi.testclient import TestClient
app = FastAPI()
def dependency_gen(value: str) -> Generator[str, None, None]:
response: Response = yield
assert isinstance(response, Response)
assert response.status_code == 200
assert response.body == b'"response_get_dependency_gen"'
response.headers["X-Test"] = value
async def dependency_async_gen(value: str) -> AsyncGenerator[str, None]:
response = yield
assert isinstance(response, Response)
assert response.status_code == 200
assert response.body == b'"response_get_dependency_async_gen"'
response.headers["X-Test"] = value
async def sub_dependency_async_gen(value: str) -> AsyncGenerator[str, None]:
response = yield
assert isinstance(response, Response)
response.status_code = 201
assert response.body == b'"response_get_sub_dependency_async_gen"'
response.headers["X-Test"] = value
async def parent_dependency(result=Depends(sub_dependency_async_gen)):
return result
@app.get("/dependency-gen", dependencies=[Depends(dependency_gen)])
async def get_dependency_gen():
return "response_get_dependency_gen"
@app.get("/dependency-async-gen", dependencies=[Depends(dependency_async_gen)])
async def get_dependency_async_gen():
return "response_get_dependency_async_gen"
@app.get("/sub-dependency-gen", dependencies=[Depends(parent_dependency)])
async def get_sub_dependency_gen():
return "response_get_sub_dependency_async_gen"
client = TestClient(app)
def test_dependency_gen():
response = client.get("/dependency-gen", params={"value": "test"})
assert response.status_code == 200
assert response.content == b'"response_get_dependency_gen"'
assert response.headers["X-Test"] == "test"
def test_dependency_async_gen():
response = client.get("/dependency-async-gen", params={"value": "test"})
assert response.status_code == 200
assert response.content == b'"response_get_dependency_async_gen"'
assert response.headers["X-Test"] == "test"
def test_sub_dependency_gen():
response = client.get("/sub-dependency-gen", params={"value": "test"})
assert response.status_code == 201
assert response.content == b'"response_get_sub_dependency_async_gen"'
assert response.headers["X-Test"] == "test"

22
tests/test_tutorial/test_dependencies/test_tutorial008e.py

@ -0,0 +1,22 @@
from fastapi.testclient import TestClient
from docs_src.dependencies.tutorial008e import app
client = TestClient(app)
def test_get_no_item():
response = client.get("/items/foo")
assert response.status_code == 404, response.text
assert response.json() == {"detail": "Item not found"}
assert "X-Username" not in response.headers
def test_get():
response = client.get("/items/plumbus")
assert response.status_code == 200, response.text
assert response.json() == {
"description": "Freshly pickled plumbus",
"owner": "Morty",
}
assert response.headers["X-Username"] == "Morty"
Loading…
Cancel
Save