
Merge 9338618f5f into 6df50d40fe (pull/13607/merge)
Aleksandr Sulimov, 2 days ago, committed by GitHub
commit ca430318b0
7 changed files:
  1. docs/en/docs/tutorial/dependencies/dependencies-with-yield.md (17 changed lines)
  2. docs_src/dependencies/tutorial008e.py (23 changed lines)
  3. fastapi/concurrency.py (201 changed lines)
  4. fastapi/dependencies/utils.py (23 changed lines)
  5. fastapi/routing.py (3 changed lines)
  6. tests/test_dependency_response.py (73 changed lines)
  7. tests/test_tutorial/test_dependencies/test_tutorial008e.py (22 changed lines)

docs/en/docs/tutorial/dependencies/dependencies-with-yield.md (17 changed lines)

@@ -19,7 +19,7 @@ Any function that is valid to use with:
would be valid to use as a **FastAPI** dependency.
-In fact, FastAPI uses those two decorators internally.
+In fact, FastAPI uses a similar decorator internally.
///
@@ -129,6 +129,19 @@ You can re-raise the same exception using `raise`:
Now the client will get the same *HTTP 500 Internal Server Error* response, but the server will have our custom `InternalError` in the logs. 😎
+## Modifying the response in dependencies after `yield`
+
+If you want to read or change the response in a dependency with `yield` after the `yield`, you can receive your endpoint's response by assigning the result of `yield` to a variable:
+
+{* ../../docs_src/dependencies/tutorial008e.py hl[13:14] *}
+
+/// warning
+
+You cannot use a dependency-injected `Response` to modify the response in a dependency after `yield`.
+
+///
## Execution of dependencies with `yield`
The sequence of execution is more or less like this diagram. Time flows from top to bottom. And each column is one of the parts interacting or executing code.
@@ -266,7 +279,7 @@ Another way to create a context manager is with:
using them to decorate a function with a single `yield`.
-That's what **FastAPI** uses internally for dependencies with `yield`.
+A similar decorator is what **FastAPI** uses internally for dependencies with `yield`.
But you don't have to use the decorators for FastAPI dependencies (and you shouldn't).
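Note (not part of the diff): the tutorial file below, tutorial008e.py, shows the sync form of this pattern. As a quick complement, here is a minimal hedged sketch of the same idea with an async dependency; the route and names are illustrative, but the PR's tests show async generator dependencies receive the response in the same way.

from fastapi import Depends, FastAPI

app = FastAPI()


async def add_cache_header():
    # Runs up to the yield before the endpoint; the finished response is
    # sent back into the generator after the endpoint has returned and
    # the response has been serialized.
    response = yield
    response.headers["Cache-Control"] = "no-store"


@app.get("/ping", dependencies=[Depends(add_cache_header)])
async def ping():
    return {"ping": "pong"}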

docs_src/dependencies/tutorial008e.py (23 changed lines)

@@ -0,0 +1,23 @@
import json

from fastapi import Depends, FastAPI, HTTPException

app = FastAPI()

data = {
    "plumbus": {"description": "Freshly pickled plumbus", "owner": "Morty"},
    "portal-gun": {"description": "Gun to create portals", "owner": "Rick"},
}


def set_username():
    response = yield
    response.headers["X-Username"] = json.loads(response.body)["owner"]


@app.get("/items/{item_id}", dependencies=[Depends(set_username)])
def get_item(item_id: str):
    if item_id not in data:
        raise HTTPException(status_code=404, detail="Item not found")
    return data[item_id]

fastapi/concurrency.py (201 changed lines)

@@ -1,5 +1,15 @@
from contextlib import asynccontextmanager as asynccontextmanager
-from typing import AsyncGenerator, ContextManager, TypeVar
+from types import TracebackType
+from typing import (
+    Any,
+    AsyncGenerator,
+    ContextManager,
+    Generator,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)

import anyio.to_thread
from anyio import CapacityLimiter
@@ -13,9 +23,9 @@ _T = TypeVar("_T")

@asynccontextmanager
-async def contextmanager_in_threadpool(
+async def contextmanager_in_threadpool(  # not used, kept for backwards compatibility
    cm: ContextManager[_T],
-) -> AsyncGenerator[_T, None]:
+) -> AsyncGenerator[_T, None]:  # pragma: no cover
    # blocking __exit__ from running waiting on a free thread
    # can create race conditions/deadlocks if the context manager itself
    # has its own internal pool (e.g. a database connection pool)
@@ -37,3 +47,188 @@ async def contextmanager_in_threadpool(
        await anyio.to_thread.run_sync(
            cm.__exit__, None, None, None, limiter=exit_limiter
        )

class _StopIteration(Exception):
    pass


class ContextManagerFromGenerator:
    """Create a context manager from a generator.

    It handles both sync and async generators. The generator has to have exactly
    one yield. The implementation is based on contextlib's contextmanager.

    Additionally, as opposed to contextlib.contextmanager/contextlib.asynccontextmanager,
    this context manager allows calling `asend` on the underlying generator and
    gracefully handles this scenario in __aexit__.

    Instances of this class cannot be reused.
    """

    def __init__(self, gen: Union[AsyncGenerator[_T, None], Generator[_T, None, None]]):
        self.gen = gen
        self._has_started = False
        self._has_executed = False

    @staticmethod
    def _send(gen: Generator[_T, None, None], value: Any) -> Any:
        # We can't raise `StopIteration` from within the threadpool executor
        # and catch it outside that context, so we coerce them into a different
        # exception type.
        try:
            return gen.send(value)
        except StopIteration:
            raise _StopIteration from None

    @staticmethod
    def _throw(gen: Generator[_T, None, None], value: Any) -> Any:
        # We can't raise `StopIteration` from within the threadpool executor
        # and catch it outside that context, so we coerce them into a different
        # exception type.
        try:
            return gen.throw(value)
        except StopIteration:
            raise _StopIteration from None

    async def __aenter__(self) -> Any:
        try:
            if isinstance(self.gen, Generator):
                result = await run_in_threadpool(self._send, self.gen, None)
            else:
                result = await self.gen.asend(None)
            self._has_started = True
            return result
        except (_StopIteration, StopAsyncIteration):
            raise RuntimeError("generator didn't yield") from None  # pragma: no cover

    async def asend(self, value: Any) -> None:
        if self._has_executed:
            raise RuntimeError(
                "ContextManagerFromGenerator can only be used once"
            )  # pragma: no cover
        if not self._has_started:
            raise RuntimeError(
                "ContextManagerFromGenerator has not been entered"
            )  # pragma: no cover
        self._has_executed = True
        try:
            if isinstance(self.gen, Generator):
                await run_in_threadpool(self._send, self.gen, value)
            else:
                await self.gen.asend(value)
        except (_StopIteration, StopAsyncIteration):
            return
        else:  # pragma: no cover
            try:
                raise RuntimeError("generator didn't stop")
            finally:
                if isinstance(self.gen, Generator):
                    self.gen.close()
                else:
                    await self.gen.aclose()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        if self._has_executed:
            if isinstance(self.gen, Generator):
                self.gen.close()
                return False
            await self.gen.aclose()
            return False
        # blocking __exit__ from running waiting on a free thread
        # can create race conditions/deadlocks if the context manager itself
        # has its own internal pool (e.g. a database connection pool)
        # to avoid this we let __exit__ run without a capacity limit
        # since we're creating a new limiter for each call, any non-zero limit
        # works (1 is arbitrary)
        exit_limiter = CapacityLimiter(1)
        if exc_type is None:  # pragma: no cover
            # usually shouldn't happen, as we call asend
            try:
                if isinstance(self.gen, Generator):
                    await anyio.to_thread.run_sync(
                        self._send,
                        self.gen,
                        None,
                        limiter=exit_limiter,
                    )
                else:
                    await self.gen.asend(None)
            except (_StopIteration, StopAsyncIteration):
                return False
            else:  # pragma: no cover
                try:
                    raise RuntimeError("generator didn't stop")
                finally:
                    if isinstance(self.gen, Generator):
                        self.gen.close()
                    else:
                        await self.gen.aclose()
        if exc_value is None:  # pragma: no cover
            # Need to force instantiation so we can reliably
            # tell if we get the same exception back
            exc_value = exc_type()
        try:
            if isinstance(self.gen, Generator):
                await anyio.to_thread.run_sync(
                    self._throw,
                    self.gen,
                    exc_value,
                    limiter=exit_limiter,
                )
            else:
                await self.gen.athrow(exc_value)
        except (StopIteration, _StopIteration, StopAsyncIteration) as exc:
            # Suppress Stop(Async)Iteration *unless* it's the same exception that
            # was passed to throw(). This prevents a Stop(Async)Iteration
            # raised inside the "with" statement from being suppressed.
            return exc is not exc_value
        except RuntimeError as exc:  # pragma: no cover
            # Don't re-raise the passed in exception. (issue27122)
            if exc is exc_value:
                exc.__traceback__ = traceback
                return False
            # Avoid suppressing if a Stop(Async)Iteration exception
            # was passed to athrow() and later wrapped into a RuntimeError
            # (see PEP 479 for sync generators; async generators also
            # have this behavior). But do this only if the exception wrapped
            # by the RuntimeError is actually Stop(Async)Iteration (see
            # issue29692).
            if (
                isinstance(
                    exc_value,
                    (StopIteration, _StopIteration, StopAsyncIteration),
                )
                and exc.__cause__ is exc_value
            ):
                exc_value.__traceback__ = traceback
                return False
            raise
        except BaseException as exc:
            # only re-raise if it's *not* the exception that was
            # passed to throw(), because __exit__() must not raise
            # an exception unless __exit__() itself failed. But throw()
            # has to raise the exception to signal propagation, so this
            # fixes the impedance mismatch between the throw() protocol
            # and the __exit__() protocol.
            if exc is not exc_value:
                raise
            exc.__traceback__ = traceback
            return False
        try:  # pragma: no cover
            raise RuntimeError("generator didn't stop after athrow()")
        finally:  # pragma: no cover
            if isinstance(self.gen, Generator):
                self.gen.close()
            else:
                await self.gen.aclose()
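To make the `asend` protocol above concrete, here is a hedged usage sketch (not part of the diff; names are illustrative) that drives `ContextManagerFromGenerator` the same way `solve_generator` and the request handler do: enter it on an `AsyncExitStack`, use the yielded value, then send the final `Response` back into the generator before the stack closes.

import asyncio
from contextlib import AsyncExitStack

from fastapi.concurrency import ContextManagerFromGenerator
from fastapi.responses import Response


def my_dependency():
    value = "dependency value"          # what Depends() would inject
    response = yield value              # resumes once the response exists
    response.headers["X-Done"] = "1"    # post-response modification


async def demo() -> None:
    async with AsyncExitStack() as stack:
        cm = ContextManagerFromGenerator(my_dependency())
        # __aenter__ runs the generator up to its single yield
        value = await stack.enter_async_context(cm)
        assert value == "dependency value"
        response = Response(content=b"ok")
        # asend() resumes the generator with the response; __aexit__ then
        # only needs to close the already-finished generator
        await cm.asend(response)
        assert response.headers["X-Done"] == "1"


asyncio.run(demo())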

fastapi/dependencies/utils.py (23 changed lines)

@@ -1,5 +1,5 @@
import inspect
-from contextlib import AsyncExitStack, contextmanager
+from contextlib import AsyncExitStack
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import (
@@ -47,10 +47,7 @@ from fastapi._compat import (
    value_is_sequence,
)
from fastapi.background import BackgroundTasks
-from fastapi.concurrency import (
-    asynccontextmanager,
-    contextmanager_in_threadpool,
-)
+from fastapi.concurrency import ContextManagerFromGenerator
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
@@ -552,12 +549,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
-) -> Any:
-    if is_gen_callable(call):
-        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
-    elif is_async_gen_callable(call):
-        cm = asynccontextmanager(call)(**sub_values)
-    return await stack.enter_async_context(cm)
+) -> Tuple[Any, Callable[[Response], Coroutine[None, None, None]]]:
+    cm = ContextManagerFromGenerator(call(**sub_values))
+    return (await stack.enter_async_context(cm), cm.asend)


@dataclass
@@ -567,6 +561,7 @@ class SolvedDependency:
    background_tasks: Optional[StarletteBackgroundTasks]
    response: Response
    dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
+    generators_callbacks: List[Callable[[Response], Coroutine[None, None, None]]]


async def solve_dependencies(
@@ -589,6 +584,7 @@ async def solve_dependencies(
        response.status_code = None  # type: ignore
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
+    generators_callbacks: List[Callable[[Response], Coroutine[None, None, None]]] = []
    for sub_dependant in dependant.dependencies:
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
@@ -625,15 +621,17 @@ async def solve_dependencies(
        )
        background_tasks = solved_result.background_tasks
        dependency_cache.update(solved_result.dependency_cache)
+        generators_callbacks.extend(solved_result.generators_callbacks)
        if solved_result.errors:
            errors.extend(solved_result.errors)
            continue
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
-            solved = await solve_generator(
+            solved, callback = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=solved_result.values
            )
+            generators_callbacks.append(callback)
        elif is_coroutine_callable(call):
            solved = await call(**solved_result.values)
        else:
@@ -692,6 +690,7 @@ async def solve_dependencies(
        background_tasks=background_tasks,
        response=response,
        dependency_cache=dependency_cache,
+        generators_callbacks=generators_callbacks,
    )

fastapi/routing.py (3 changed lines)

@@ -340,6 +340,9 @@ def get_request_handler(
                        if not is_body_allowed_for_status_code(response.status_code):
                            response.body = b""
                        response.headers.raw.extend(solved_result.response.headers.raw)
+                        # pass response to generator dependencies
+                        for callback in reversed(solved_result.generators_callbacks):
+                            await callback(response)
            if errors:
                validation_error = RequestValidationError(
                    _normalize_errors(errors), body=body
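The loop added above hands the finished response to every generator dependency's callback. A small hedged sketch of the observable effect (illustrative route and header names, not taken from the diff): two dependencies with `yield` on the same route each stamp their own header onto the final response after the endpoint has returned.

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


def add_request_id():
    response = yield
    response.headers["X-Request-ID"] = "demo-123"


async def add_served_by():
    response = yield
    response.headers["X-Served-By"] = "worker-1"


@app.get("/status", dependencies=[Depends(add_request_id), Depends(add_served_by)])
def status():
    return {"ok": True}


client = TestClient(app)
r = client.get("/status")
# both generator dependencies received the same final response object
assert r.headers["X-Request-ID"] == "demo-123"
assert r.headers["X-Served-By"] == "worker-1"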

tests/test_dependency_response.py (73 changed lines)

@@ -0,0 +1,73 @@
from typing import AsyncGenerator, Generator

from fastapi import Depends, FastAPI, Response
from fastapi.testclient import TestClient

app = FastAPI()


def dependency_gen(value: str) -> Generator[str, None, None]:
    response: Response = yield
    assert isinstance(response, Response)
    assert response.status_code == 200
    assert response.body == b'"response_get_dependency_gen"'
    response.headers["X-Test"] = value


async def dependency_async_gen(value: str) -> AsyncGenerator[str, None]:
    response = yield
    assert isinstance(response, Response)
    assert response.status_code == 200
    assert response.body == b'"response_get_dependency_async_gen"'
    response.headers["X-Test"] = value


async def sub_dependency_async_gen(value: str) -> AsyncGenerator[str, None]:
    response = yield
    assert isinstance(response, Response)
    response.status_code = 201
    assert response.body == b'"response_get_sub_dependency_async_gen"'
    response.headers["X-Test"] = value


async def parent_dependency(result=Depends(sub_dependency_async_gen)):
    return result


@app.get("/dependency-gen", dependencies=[Depends(dependency_gen)])
async def get_dependency_gen():
    return "response_get_dependency_gen"


@app.get("/dependency-async-gen", dependencies=[Depends(dependency_async_gen)])
async def get_dependency_async_gen():
    return "response_get_dependency_async_gen"


@app.get("/sub-dependency-gen", dependencies=[Depends(parent_dependency)])
async def get_sub_dependency_gen():
    return "response_get_sub_dependency_async_gen"


client = TestClient(app)


def test_dependency_gen():
    response = client.get("/dependency-gen", params={"value": "test"})
    assert response.status_code == 200
    assert response.content == b'"response_get_dependency_gen"'
    assert response.headers["X-Test"] == "test"


def test_dependency_async_gen():
    response = client.get("/dependency-async-gen", params={"value": "test"})
    assert response.status_code == 200
    assert response.content == b'"response_get_dependency_async_gen"'
    assert response.headers["X-Test"] == "test"


def test_sub_dependency_gen():
    response = client.get("/sub-dependency-gen", params={"value": "test"})
    assert response.status_code == 201
    assert response.content == b'"response_get_sub_dependency_async_gen"'
    assert response.headers["X-Test"] == "test"

tests/test_tutorial/test_dependencies/test_tutorial008e.py (22 changed lines)

@@ -0,0 +1,22 @@
from fastapi.testclient import TestClient

from docs_src.dependencies.tutorial008e import app

client = TestClient(app)


def test_get_no_item():
    response = client.get("/items/foo")
    assert response.status_code == 404, response.text
    assert response.json() == {"detail": "Item not found"}
    assert "X-Username" not in response.headers


def test_get():
    response = client.get("/items/plumbus")
    assert response.status_code == 200, response.text
    assert response.json() == {
        "description": "Freshly pickled plumbus",
        "owner": "Morty",
    }
    assert response.headers["X-Username"] == "Morty"