Browse Source

Implement dependency value cache per request (#292)

*  Add dependency cache, with support for disabling it

*  Add tests for dependency cache

* 📝 Add docs about dependency value caching
pull/294/head
Sebastián Ramírez 6 years ago
committed by GitHub
parent
commit
bff5dbbf5d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      docs/tutorial/dependencies/first-steps.md
  2. 15
      docs/tutorial/dependencies/sub-dependencies.md
  3. 4
      fastapi/dependencies/models.py
  4. 49
      fastapi/dependencies/utils.py
  5. 10
      fastapi/param_functions.py
  6. 13
      fastapi/params.py
  7. 4
      fastapi/routing.py
  8. 68
      tests/test_dependency_cache.py

3
docs/tutorial/dependencies/first-steps.md

@ -17,14 +17,12 @@ This is very useful when you need to:
All these, while minimizing code repetition. All these, while minimizing code repetition.
## First Steps ## First Steps
Let's see a very simple example. It will be so simple that it is not very useful, for now. Let's see a very simple example. It will be so simple that it is not very useful, for now.
But this way we can focus on how the **Dependency Injection** system works. But this way we can focus on how the **Dependency Injection** system works.
### Create a dependency, or "dependable" ### Create a dependency, or "dependable"
Let's first focus on the dependency. Let's first focus on the dependency.
@ -151,7 +149,6 @@ The simplicity of the dependency injection system makes **FastAPI** compatible w
* response data injection systems * response data injection systems
* etc. * etc.
## Simple and Powerful ## Simple and Powerful
Although the hierarchical dependency injection system is very simple to define and use, it's still very powerful. Although the hierarchical dependency injection system is very simple to define and use, it's still very powerful.

15
docs/tutorial/dependencies/sub-dependencies.md

@ -11,6 +11,7 @@ You could create a first dependency ("dependable") like:
```Python hl_lines="6 7" ```Python hl_lines="6 7"
{!./src/dependencies/tutorial005.py!} {!./src/dependencies/tutorial005.py!}
``` ```
It declares an optional query parameter `q` as a `str`, and then it just returns it. It declares an optional query parameter `q` as a `str`, and then it just returns it.
This is quite simple (not very useful), but will help us focus on how the sub-dependencies work. This is quite simple (not very useful), but will help us focus on how the sub-dependencies work.
@ -43,6 +44,18 @@ Then we can use the dependency with:
But **FastAPI** will know that it has to solve `query_extractor` first, to pass the results of that to `query_or_cookie_extractor` while calling it. But **FastAPI** will know that it has to solve `query_extractor` first, to pass the results of that to `query_or_cookie_extractor` while calling it.
## Using the same dependency multiple times
If one of your dependencies is declared multiple times for the same *path operation*, for example, multiple dependencies have a common sub-dependency, **FastAPI** will know to call that sub-dependency only once per request.
And it will save the returned value in a <abbr title="A utility/system to store computed/generated values, to re-use them instead of computing them again.">"cache"</abbr> and pass it to all the "dependants" that need it in that specific request, instead of calling the dependency multiple times for the same request.
In an advanced scenario where you know you need the dependency to be called at every step (possibly multiple times) in the same request instead of using the "cached" value, you can set the parameter `use_cache=False` when using `Depends`:
```Python hl_lines="1"
async def needy_dependency(fresh_value: str = Depends(get_value, use_cache=False)):
return {"fresh_value": fresh_value}
```
## Recap ## Recap
@ -54,7 +67,7 @@ But still, it is very powerful, and allows you to declare arbitrarily deeply nes
!!! tip !!! tip
All this might not seem as useful with these simple examples. All this might not seem as useful with these simple examples.
But you will see how useful it is in the chapters about **security**. But you will see how useful it is in the chapters about **security**.
And you will also see the amounts of code it will save you. And you will also see the amounts of code it will save you.

4
fastapi/dependencies/models.py

@ -30,6 +30,7 @@ class Dependant:
background_tasks_param_name: str = None, background_tasks_param_name: str = None,
security_scopes_param_name: str = None, security_scopes_param_name: str = None,
security_scopes: List[str] = None, security_scopes: List[str] = None,
use_cache: bool = True,
path: str = None, path: str = None,
) -> None: ) -> None:
self.path_params = path_params or [] self.path_params = path_params or []
@ -46,5 +47,8 @@ class Dependant:
self.security_scopes_param_name = security_scopes_param_name self.security_scopes_param_name = security_scopes_param_name
self.name = name self.name = name
self.call = call self.call = call
self.use_cache = use_cache
# Store the path to be able to re-generate a dependable from it in overrides # Store the path to be able to re-generate a dependable from it in overrides
self.path = path self.path = path
# Save the cache key at creation to optimize performance
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

49
fastapi/dependencies/utils.py

@ -95,7 +95,11 @@ def get_sub_dependant(
security_scheme=dependency, scopes=use_scopes security_scheme=dependency, scopes=use_scopes
) )
sub_dependant = get_dependant( sub_dependant = get_dependant(
path=path, call=dependency, name=name, security_scopes=security_scopes path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
) )
if security_requirement: if security_requirement:
sub_dependant.security_requirements.append(security_requirement) sub_dependant.security_requirements.append(security_requirement)
@ -111,6 +115,7 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
cookie_params=dependant.cookie_params.copy(), cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(), body_params=dependant.body_params.copy(),
security_schemes=dependant.security_requirements.copy(), security_schemes=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
path=dependant.path, path=dependant.path,
) )
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
@ -148,12 +153,17 @@ def is_scalar_sequence_field(field: Field) -> bool:
def get_dependant( def get_dependant(
*, path: str, call: Callable, name: str = None, security_scopes: List[str] = None *,
path: str,
call: Callable,
name: str = None,
security_scopes: List[str] = None,
use_cache: bool = True,
) -> Dependant: ) -> Dependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call) endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name, path=path) dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
if isinstance(param.default, params.Depends): if isinstance(param.default, params.Depends):
sub_dependant = get_param_sub_dependant( sub_dependant = get_param_sub_dependant(
@ -286,18 +296,29 @@ async def solve_dependencies(
body: Dict[str, Any] = None, body: Dict[str, Any] = None,
background_tasks: BackgroundTasks = None, background_tasks: BackgroundTasks = None,
dependency_overrides_provider: Any = None, dependency_overrides_provider: Any = None,
) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks]]: dependency_cache: Dict[Tuple[Callable, Tuple[str]], Any] = None,
) -> Tuple[
Dict[str, Any],
List[ErrorWrapper],
Optional[BackgroundTasks],
Dict[Tuple[Callable, Tuple[str]], Any],
]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = [] errors: List[ErrorWrapper] = []
dependency_cache = dependency_cache or {}
sub_dependant: Dependant sub_dependant: Dependant
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
call: Callable = sub_dependant.call # type: ignore sub_dependant.call = cast(Callable, sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable, Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if ( if (
dependency_overrides_provider dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides and dependency_overrides_provider.dependency_overrides
): ):
original_call: Callable = sub_dependant.call # type: ignore original_call = sub_dependant.call
call = getattr( call = getattr(
dependency_overrides_provider, "dependency_overrides", {} dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call) ).get(original_call, original_call)
@ -309,22 +330,28 @@ async def solve_dependencies(
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
) )
sub_values, sub_errors, background_tasks = await solve_dependencies( sub_values, sub_errors, background_tasks, sub_dependency_cache = await solve_dependencies(
request=request, request=request,
dependant=use_sub_dependant, dependant=use_sub_dependant,
body=body, body=body,
background_tasks=background_tasks, background_tasks=background_tasks,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
) )
dependency_cache.update(sub_dependency_cache)
if sub_errors: if sub_errors:
errors.extend(sub_errors) errors.extend(sub_errors)
continue continue
if is_coroutine_callable(call): if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_coroutine_callable(call):
solved = await call(**sub_values) solved = await call(**sub_values)
else: else:
solved = await run_in_threadpool(call, **sub_values) solved = await run_in_threadpool(call, **sub_values)
if use_sub_dependant.name is not None: if sub_dependant.name is not None:
values[use_sub_dependant.name] = solved values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args( path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params dependant.path_params, request.path_params
) )
@ -360,7 +387,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes( values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes scopes=dependant.security_scopes
) )
return values, errors, background_tasks return values, errors, background_tasks, dependency_cache
def request_params_to_args( def request_params_to_args(

10
fastapi/param_functions.py

@ -238,11 +238,13 @@ def File( # noqa: N802
) )
def Depends(dependency: Callable = None) -> Any: # noqa: N802 def Depends( # noqa: N802
return params.Depends(dependency=dependency) dependency: Callable = None, *, use_cache: bool = True
) -> Any:
return params.Depends(dependency=dependency, use_cache=use_cache)
def Security( # noqa: N802 def Security( # noqa: N802
dependency: Callable = None, scopes: Sequence[str] = None dependency: Callable = None, *, scopes: Sequence[str] = None, use_cache: bool = True
) -> Any: ) -> Any:
return params.Security(dependency=dependency, scopes=scopes) return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)

13
fastapi/params.py

@ -308,11 +308,18 @@ class File(Form):
class Depends: class Depends:
def __init__(self, dependency: Callable = None): def __init__(self, dependency: Callable = None, *, use_cache: bool = True):
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache
class Security(Depends): class Security(Depends):
def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None): def __init__(
self,
dependency: Callable = None,
*,
scopes: Sequence[str] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or [] self.scopes = scopes or []
super().__init__(dependency=dependency)

4
fastapi/routing.py

@ -102,7 +102,7 @@ def get_app(
raise HTTPException( raise HTTPException(
status_code=400, detail="There was an error parsing the body" status_code=400, detail="There was an error parsing the body"
) from e ) from e
values, errors, background_tasks = await solve_dependencies( values, errors, background_tasks, _ = await solve_dependencies(
request=request, request=request,
dependant=dependant, dependant=dependant,
body=body, body=body,
@ -141,7 +141,7 @@ def get_websocket_app(
dependant: Dependant, dependency_overrides_provider: Any = None dependant: Dependant, dependency_overrides_provider: Any = None
) -> Callable: ) -> Callable:
async def app(websocket: WebSocket) -> None: async def app(websocket: WebSocket) -> None:
values, errors, _ = await solve_dependencies( values, errors, _, _2 = await solve_dependencies(
request=websocket, request=websocket,
dependant=dependant, dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,

68
tests/test_dependency_cache.py

@ -0,0 +1,68 @@
from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
app = FastAPI()
counter_holder = {"counter": 0}
async def dep_counter():
counter_holder["counter"] += 1
return counter_holder["counter"]
async def super_dep(count: int = Depends(dep_counter)):
return count
@app.get("/counter/")
async def get_counter(count: int = Depends(dep_counter)):
return {"counter": count}
@app.get("/sub-counter/")
async def get_sub_counter(
subcount: int = Depends(super_dep), count: int = Depends(dep_counter)
):
return {"counter": count, "subcounter": subcount}
@app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep),
count: int = Depends(dep_counter, use_cache=False),
):
return {"counter": count, "subcounter": subcount}
client = TestClient(app)
def test_normal_counter():
counter_holder["counter"] = 0
response = client.get("/counter/")
assert response.status_code == 200
assert response.json() == {"counter": 1}
response = client.get("/counter/")
assert response.status_code == 200
assert response.json() == {"counter": 2}
def test_sub_counter():
counter_holder["counter"] = 0
response = client.get("/sub-counter/")
assert response.status_code == 200
assert response.json() == {"counter": 1, "subcounter": 1}
response = client.get("/sub-counter/")
assert response.status_code == 200
assert response.json() == {"counter": 2, "subcounter": 2}
def test_sub_counter_no_cache():
counter_holder["counter"] = 0
response = client.get("/sub-counter-no-cache/")
assert response.status_code == 200
assert response.json() == {"counter": 2, "subcounter": 1}
response = client.get("/sub-counter-no-cache/")
assert response.status_code == 200
assert response.json() == {"counter": 4, "subcounter": 3}
Loading…
Cancel
Save