diff --git a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md index 2b97ba39e..5b33b6828 100644 --- a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md +++ b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md @@ -19,7 +19,7 @@ Any function that is valid to use with: would be valid to use as a **FastAPI** dependency. -In fact, FastAPI uses those two decorators internally. +In fact, FastAPI uses similar decorator internally. /// @@ -129,6 +129,19 @@ You can re-raise the same exception using `raise`: Now the client will get the same *HTTP 500 Internal Server Error* response, but the server will have our custom `InternalError` in the logs. 😎 +## Modifying response in dependencies after `yield` + +If you want to read or change response in a dependency with `yield` after `yield`, you can get your endpoint's response +if you assign `yield` to it: + +{* ../../docs_src/dependencies/tutorial008e.py hl[13:14] *} + +/// warning + +You cannot use dependency-injected `Response` to modify response in dependency after `yield`. + +/// + ## Execution of dependencies with `yield` The sequence of execution is more or less like this diagram. Time flows from top to bottom. And each column is one of the parts interacting or executing code. @@ -266,7 +279,7 @@ Another way to create a context manager is with: using them to decorate a function with a single `yield`. -That's what **FastAPI** uses internally for dependencies with `yield`. +A similar decorator is what **FastAPI** uses internally for dependencies with `yield`. But you don't have to use the decorators for FastAPI dependencies (and you shouldn't). diff --git a/docs_src/dependencies/tutorial008e.py b/docs_src/dependencies/tutorial008e.py new file mode 100644 index 000000000..bf634233d --- /dev/null +++ b/docs_src/dependencies/tutorial008e.py @@ -0,0 +1,23 @@ +import json + +from fastapi import Depends, FastAPI, HTTPException + +app = FastAPI() + + +data = { + "plumbus": {"description": "Freshly pickled plumbus", "owner": "Morty"}, + "portal-gun": {"description": "Gun to create portals", "owner": "Rick"}, +} + + +def set_username(): + response = yield + response.headers["X-Username"] = json.loads(response.body)["owner"] + + +@app.get("/items/{item_id}", dependencies=[Depends(set_username)]) +def get_item(item_id: str): + if item_id not in data: + raise HTTPException(status_code=404, detail="Item not found") + return data[item_id] diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 3202c7078..643e09d47 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,5 +1,15 @@ from contextlib import asynccontextmanager as asynccontextmanager -from typing import AsyncGenerator, ContextManager, TypeVar +from types import TracebackType +from typing import ( + Any, + AsyncGenerator, + ContextManager, + Generator, + Optional, + Type, + TypeVar, + Union, +) import anyio.to_thread from anyio import CapacityLimiter @@ -13,9 +23,9 @@ _T = TypeVar("_T") @asynccontextmanager -async def contextmanager_in_threadpool( +async def contextmanager_in_threadpool( # not used, kept for backwards compatibility cm: ContextManager[_T], -) -> AsyncGenerator[_T, None]: +) -> AsyncGenerator[_T, None]: # pragma: no cover # blocking __exit__ from running waiting on a free thread # can create race conditions/deadlocks if the context manager itself # has its own internal pool (e.g. a database connection pool) @@ -37,3 +47,188 @@ async def contextmanager_in_threadpool( await anyio.to_thread.run_sync( cm.__exit__, None, None, None, limiter=exit_limiter ) + + +class _StopIteration(Exception): + pass + + +class ContextManagerFromGenerator: + """Create a context manager from a generator. + + It handles both sync and async generators. Generator has to have exactly one yield. + + Implementation is based on contextlib's contextmanager. + + Additionally, as apposed to contextlib.contextmanager/contextlib.asynccontextmanager, + this context manager allows to call `asend` on underlaying generator and gracefully handle + this scenario in __aexit__. + + Instances of this class cannot be reused. + """ + + def __init__(self, gen: Union[AsyncGenerator[_T, None], Generator[_T, None, None]]): + self.gen = gen + self._has_started = False + self._has_executed = False + + @staticmethod + def _send(gen: Generator[_T, None, None], value: Any) -> Any: + # We can't raise `StopIteration` from within the threadpool executor + # and catch it outside that context, so we coerce them into a different + # exception type. + try: + return gen.send(value) + except StopIteration: + raise _StopIteration from None + + @staticmethod + def _throw(gen: Generator[_T, None, None], value: Any) -> Any: + # We can't raise `StopIteration` from within the threadpool executor + # and catch it outside that context, so we coerce them into a different + # exception type. + try: + return gen.throw(value) + except StopIteration: + raise _StopIteration from None + + async def __aenter__(self) -> Any: + try: + if isinstance(self.gen, Generator): + result = await run_in_threadpool(self._send, self.gen, None) + else: + result = await self.gen.asend(None) + self._has_started = True + return result + except (_StopIteration, StopAsyncIteration): + raise RuntimeError("generator didn't yield") from None # pragma: no cover + + async def asend(self, value: Any) -> None: + if self._has_executed: + raise RuntimeError( + "ContextManagerFromGenerator can only be used once" + ) # pragma: no cover + if not self._has_started: + raise RuntimeError( + "ContextManagerFromGenerator has not been entered" + ) # pragma: no cover + self._has_executed = True + try: + if isinstance(self.gen, Generator): + await run_in_threadpool(self._send, self.gen, value) + else: + await self.gen.asend(value) + except (_StopIteration, StopAsyncIteration): + return + else: # pragma: no cover + try: + raise RuntimeError("generator didn't stop") + finally: + if isinstance(self.gen, Generator): + self.gen.close() + else: + await self.gen.aclose() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: + if self._has_executed: + if isinstance(self.gen, Generator): + self.gen.close() + return False + await self.gen.aclose() + return False + + # blocking __exit__ from running waiting on a free thread + # can create race conditions/deadlocks if the context manager itself + # has its own internal pool (e.g. a database connection pool) + # to avoid this we let __exit__ run without a capacity limit + # since we're creating a new limiter for each call, any non-zero limit + # works (1 is arbitrary) + exit_limiter = CapacityLimiter(1) + + if exc_type is None: # pragma: no cover + # usually shouldn't happen, as we call asend + try: + if isinstance(self.gen, Generator): + await anyio.to_thread.run_sync( + self._send, + self.gen, + None, + limiter=exit_limiter, + ) + else: + await self.gen.asend(None) + except (_StopIteration, StopAsyncIteration): + return False + else: # pragma: no cover + try: + raise RuntimeError("generator didn't stop") + finally: + if isinstance(self.gen, Generator): + self.gen.close() + else: + await self.gen.aclose() + + if exc_value is None: # pragma: no cover + # Need to force instantiation so we can reliably + # tell if we get the same exception back + exc_value = exc_type() + + try: + if isinstance(self.gen, Generator): + await anyio.to_thread.run_sync( + self._throw, + self.gen, + exc_value, + limiter=exit_limiter, + ) + else: + await self.gen.athrow(exc_value) + except (StopIteration, _StopIteration, StopAsyncIteration) as exc: + # Suppress Stop(Async)Iteration *unless* it's the same exception that + # was passed to throw(). This prevents a Stop(Async)Iteration + # raised inside the "with" statement from being suppressed. + return exc is not exc_value + except RuntimeError as exc: # pragma: no cover + # Don't re-raise the passed in exception. (issue27122) + if exc is exc_value: + exc.__traceback__ = traceback + return False + # Avoid suppressing if a Stop(Async)Iteration exception + # was passed to athrow() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception wrapped + # by the RuntimeError is actually Stop(Async)Iteration (see + # issue29692). + if ( + isinstance( + exc_value, + (StopIteration, _StopIteration, StopAsyncIteration), + ) + and exc.__cause__ is exc_value + ): + exc_value.__traceback__ = traceback + return False + raise + except BaseException as exc: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + if exc is not exc_value: + raise + exc.__traceback__ = traceback + return False + try: # pragma: no cover + raise RuntimeError("generator didn't stop after athrow()") + finally: # pragma: no cover + if isinstance(self.gen, Generator): + self.gen.close() + else: + await self.gen.aclose() diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..3e43906c7 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,5 +1,5 @@ import inspect -from contextlib import AsyncExitStack, contextmanager +from contextlib import AsyncExitStack from copy import copy, deepcopy from dataclasses import dataclass from typing import ( @@ -47,10 +47,7 @@ from fastapi._compat import ( value_is_sequence, ) from fastapi.background import BackgroundTasks -from fastapi.concurrency import ( - asynccontextmanager, - contextmanager_in_threadpool, -) +from fastapi.concurrency import ContextManagerFromGenerator from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger from fastapi.security.base import SecurityBase @@ -552,12 +549,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] -) -> Any: - if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif is_async_gen_callable(call): - cm = asynccontextmanager(call)(**sub_values) - return await stack.enter_async_context(cm) +) -> Tuple[Any, Callable[[Response], Coroutine[None, None, None]]]: + cm = ContextManagerFromGenerator(call(**sub_values)) + return (await stack.enter_async_context(cm), cm.asend) @dataclass @@ -567,6 +561,7 @@ class SolvedDependency: background_tasks: Optional[StarletteBackgroundTasks] response: Response dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] + generators_callbacks: List[Callable[[Response], Coroutine[None, None, None]]] async def solve_dependencies( @@ -589,6 +584,7 @@ async def solve_dependencies( response.status_code = None # type: ignore dependency_cache = dependency_cache or {} sub_dependant: Dependant + generators_callbacks: List[Callable[[Response], Coroutine[None, None, None]]] = [] for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.cache_key = cast( @@ -625,15 +621,17 @@ async def solve_dependencies( ) background_tasks = solved_result.background_tasks dependency_cache.update(solved_result.dependency_cache) + generators_callbacks.extend(solved_result.generators_callbacks) if solved_result.errors: errors.extend(solved_result.errors) continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): - solved = await solve_generator( + solved, callback = await solve_generator( call=call, stack=async_exit_stack, sub_values=solved_result.values ) + generators_callbacks.append(callback) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: @@ -692,6 +690,7 @@ async def solve_dependencies( background_tasks=background_tasks, response=response, dependency_cache=dependency_cache, + generators_callbacks=generators_callbacks, ) diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..f30b125e5 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -340,6 +340,9 @@ def get_request_handler( if not is_body_allowed_for_status_code(response.status_code): response.body = b"" response.headers.raw.extend(solved_result.response.headers.raw) + # pass response to generator dependencies + for callback in reversed(solved_result.generators_callbacks): + await callback(response) if errors: validation_error = RequestValidationError( _normalize_errors(errors), body=body diff --git a/tests/test_dependency_response.py b/tests/test_dependency_response.py new file mode 100644 index 000000000..6b2ad4d4f --- /dev/null +++ b/tests/test_dependency_response.py @@ -0,0 +1,73 @@ +from typing import AsyncGenerator, Generator + +from fastapi import Depends, FastAPI, Response +from fastapi.testclient import TestClient + +app = FastAPI() + + +def dependency_gen(value: str) -> Generator[str, None, None]: + response: Response = yield + assert isinstance(response, Response) + assert response.status_code == 200 + assert response.body == b'"response_get_dependency_gen"' + response.headers["X-Test"] = value + + +async def dependency_async_gen(value: str) -> AsyncGenerator[str, None]: + response = yield + assert isinstance(response, Response) + assert response.status_code == 200 + assert response.body == b'"response_get_dependency_async_gen"' + response.headers["X-Test"] = value + + +async def sub_dependency_async_gen(value: str) -> AsyncGenerator[str, None]: + response = yield + assert isinstance(response, Response) + response.status_code = 201 + assert response.body == b'"response_get_sub_dependency_async_gen"' + response.headers["X-Test"] = value + + +async def parent_dependency(result=Depends(sub_dependency_async_gen)): + return result + + +@app.get("/dependency-gen", dependencies=[Depends(dependency_gen)]) +async def get_dependency_gen(): + return "response_get_dependency_gen" + + +@app.get("/dependency-async-gen", dependencies=[Depends(dependency_async_gen)]) +async def get_dependency_async_gen(): + return "response_get_dependency_async_gen" + + +@app.get("/sub-dependency-gen", dependencies=[Depends(parent_dependency)]) +async def get_sub_dependency_gen(): + return "response_get_sub_dependency_async_gen" + + +client = TestClient(app) + + +def test_dependency_gen(): + response = client.get("/dependency-gen", params={"value": "test"}) + assert response.status_code == 200 + assert response.content == b'"response_get_dependency_gen"' + assert response.headers["X-Test"] == "test" + + +def test_dependency_async_gen(): + response = client.get("/dependency-async-gen", params={"value": "test"}) + assert response.status_code == 200 + assert response.content == b'"response_get_dependency_async_gen"' + assert response.headers["X-Test"] == "test" + + +def test_sub_dependency_gen(): + response = client.get("/sub-dependency-gen", params={"value": "test"}) + assert response.status_code == 201 + assert response.content == b'"response_get_sub_dependency_async_gen"' + assert response.headers["X-Test"] == "test" diff --git a/tests/test_tutorial/test_dependencies/test_tutorial008e.py b/tests/test_tutorial/test_dependencies/test_tutorial008e.py new file mode 100644 index 000000000..900e747da --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial008e.py @@ -0,0 +1,22 @@ +from fastapi.testclient import TestClient + +from docs_src.dependencies.tutorial008e import app + +client = TestClient(app) + + +def test_get_no_item(): + response = client.get("/items/foo") + assert response.status_code == 404, response.text + assert response.json() == {"detail": "Item not found"} + assert "X-Username" not in response.headers + + +def test_get(): + response = client.get("/items/plumbus") + assert response.status_code == 200, response.text + assert response.json() == { + "description": "Freshly pickled plumbus", + "owner": "Morty", + } + assert response.headers["X-Username"] == "Morty"