From bff5dbbf5d95812e129fb2ebfa6323879b33644e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 5 Jun 2019 21:00:54 +0400 Subject: [PATCH] :sparkles: Implement dependency value cache per request (#292) * :sparkles: Add dependency cache, with support for disabling it * :white_check_mark: Add tests for dependency cache * :memo: Add docs about dependency value caching --- docs/tutorial/dependencies/first-steps.md | 3 - .../tutorial/dependencies/sub-dependencies.md | 15 +++- fastapi/dependencies/models.py | 4 ++ fastapi/dependencies/utils.py | 49 ++++++++++--- fastapi/param_functions.py | 10 +-- fastapi/params.py | 13 +++- fastapi/routing.py | 4 +- tests/test_dependency_cache.py | 68 +++++++++++++++++++ 8 files changed, 142 insertions(+), 24 deletions(-) create mode 100644 tests/test_dependency_cache.py diff --git a/docs/tutorial/dependencies/first-steps.md b/docs/tutorial/dependencies/first-steps.md index 7a19618a3..601fa6245 100644 --- a/docs/tutorial/dependencies/first-steps.md +++ b/docs/tutorial/dependencies/first-steps.md @@ -17,14 +17,12 @@ This is very useful when you need to: All these, while minimizing code repetition. - ## First Steps Let's see a very simple example. It will be so simple that it is not very useful, for now. But this way we can focus on how the **Dependency Injection** system works. - ### Create a dependency, or "dependable" Let's first focus on the dependency. @@ -151,7 +149,6 @@ The simplicity of the dependency injection system makes **FastAPI** compatible w * response data injection systems * etc. - ## Simple and Powerful Although the hierarchical dependency injection system is very simple to define and use, it's still very powerful. diff --git a/docs/tutorial/dependencies/sub-dependencies.md b/docs/tutorial/dependencies/sub-dependencies.md index 7f96674f3..e55dd14e4 100644 --- a/docs/tutorial/dependencies/sub-dependencies.md +++ b/docs/tutorial/dependencies/sub-dependencies.md @@ -11,6 +11,7 @@ You could create a first dependency ("dependable") like: ```Python hl_lines="6 7" {!./src/dependencies/tutorial005.py!} ``` + It declares an optional query parameter `q` as a `str`, and then it just returns it. This is quite simple (not very useful), but will help us focus on how the sub-dependencies work. @@ -43,6 +44,18 @@ Then we can use the dependency with: But **FastAPI** will know that it has to solve `query_extractor` first, to pass the results of that to `query_or_cookie_extractor` while calling it. +## Using the same dependency multiple times + +If one of your dependencies is declared multiple times for the same *path operation*, for example, multiple dependencies have a common sub-dependency, **FastAPI** will know to call that sub-dependency only once per request. + +And it will save the returned value in a "cache" and pass it to all the "dependants" that need it in that specific request, instead of calling the dependency multiple times for the same request. + +In an advanced scenario where you know you need the dependency to be called at every step (possibly multiple times) in the same request instead of using the "cached" value, you can set the parameter `use_cache=False` when using `Depends`: + +```Python hl_lines="1" +async def needy_dependency(fresh_value: str = Depends(get_value, use_cache=False)): + return {"fresh_value": fresh_value} +``` ## Recap @@ -54,7 +67,7 @@ But still, it is very powerful, and allows you to declare arbitrarily deeply nes !!! tip All this might not seem as useful with these simple examples. - + But you will see how useful it is in the chapters about **security**. And you will also see the amounts of code it will save you. diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 33644d764..29fdd0e22 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -30,6 +30,7 @@ class Dependant: background_tasks_param_name: str = None, security_scopes_param_name: str = None, security_scopes: List[str] = None, + use_cache: bool = True, path: str = None, ) -> None: self.path_params = path_params or [] @@ -46,5 +47,8 @@ class Dependant: self.security_scopes_param_name = security_scopes_param_name self.name = name self.call = call + self.use_cache = use_cache # Store the path to be able to re-generate a dependable from it in overrides self.path = path + # Save the cache key at creation to optimize performance + self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 2a64172ef..e79a9a6a0 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -95,7 +95,11 @@ def get_sub_dependant( security_scheme=dependency, scopes=use_scopes ) sub_dependant = get_dependant( - path=path, call=dependency, name=name, security_scopes=security_scopes + path=path, + call=dependency, + name=name, + security_scopes=security_scopes, + use_cache=depends.use_cache, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -111,6 +115,7 @@ def get_flat_dependant(dependant: Dependant) -> Dependant: cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), security_schemes=dependant.security_requirements.copy(), + use_cache=dependant.use_cache, path=dependant.path, ) for sub_dependant in dependant.dependencies: @@ -148,12 +153,17 @@ def is_scalar_sequence_field(field: Field) -> bool: def get_dependant( - *, path: str, call: Callable, name: str = None, security_scopes: List[str] = None + *, + path: str, + call: Callable, + name: str = None, + security_scopes: List[str] = None, + use_cache: bool = True, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = inspect.signature(call) signature_params = endpoint_signature.parameters - dependant = Dependant(call=call, name=name, path=path) + dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache) for param_name, param in signature_params.items(): if isinstance(param.default, params.Depends): sub_dependant = get_param_sub_dependant( @@ -286,18 +296,29 @@ async def solve_dependencies( body: Dict[str, Any] = None, background_tasks: BackgroundTasks = None, dependency_overrides_provider: Any = None, -) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks]]: + dependency_cache: Dict[Tuple[Callable, Tuple[str]], Any] = None, +) -> Tuple[ + Dict[str, Any], + List[ErrorWrapper], + Optional[BackgroundTasks], + Dict[Tuple[Callable, Tuple[str]], Any], +]: values: Dict[str, Any] = {} errors: List[ErrorWrapper] = [] + dependency_cache = dependency_cache or {} sub_dependant: Dependant for sub_dependant in dependant.dependencies: - call: Callable = sub_dependant.call # type: ignore + sub_dependant.call = cast(Callable, sub_dependant.call) + sub_dependant.cache_key = cast( + Tuple[Callable, Tuple[str]], sub_dependant.cache_key + ) + call = sub_dependant.call use_sub_dependant = sub_dependant if ( dependency_overrides_provider and dependency_overrides_provider.dependency_overrides ): - original_call: Callable = sub_dependant.call # type: ignore + original_call = sub_dependant.call call = getattr( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) @@ -309,22 +330,28 @@ async def solve_dependencies( security_scopes=sub_dependant.security_scopes, ) - sub_values, sub_errors, background_tasks = await solve_dependencies( + sub_values, sub_errors, background_tasks, sub_dependency_cache = await solve_dependencies( request=request, dependant=use_sub_dependant, body=body, background_tasks=background_tasks, dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, ) + dependency_cache.update(sub_dependency_cache) if sub_errors: errors.extend(sub_errors) continue - if is_coroutine_callable(call): + if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: + solved = dependency_cache[sub_dependant.cache_key] + elif is_coroutine_callable(call): solved = await call(**sub_values) else: solved = await run_in_threadpool(call, **sub_values) - if use_sub_dependant.name is not None: - values[use_sub_dependant.name] = solved + if sub_dependant.name is not None: + values[sub_dependant.name] = solved + if sub_dependant.cache_key not in dependency_cache: + dependency_cache[sub_dependant.cache_key] = solved path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params ) @@ -360,7 +387,7 @@ async def solve_dependencies( values[dependant.security_scopes_param_name] = SecurityScopes( scopes=dependant.security_scopes ) - return values, errors, background_tasks + return values, errors, background_tasks, dependency_cache def request_params_to_args( diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 92c83ba9a..abd95609c 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -238,11 +238,13 @@ def File( # noqa: N802 ) -def Depends(dependency: Callable = None) -> Any: # noqa: N802 - return params.Depends(dependency=dependency) +def Depends( # noqa: N802 + dependency: Callable = None, *, use_cache: bool = True +) -> Any: + return params.Depends(dependency=dependency, use_cache=use_cache) def Security( # noqa: N802 - dependency: Callable = None, scopes: Sequence[str] = None + dependency: Callable = None, *, scopes: Sequence[str] = None, use_cache: bool = True ) -> Any: - return params.Security(dependency=dependency, scopes=scopes) + return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache) diff --git a/fastapi/params.py b/fastapi/params.py index 3d9afec78..0541a3695 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -308,11 +308,18 @@ class File(Form): class Depends: - def __init__(self, dependency: Callable = None): + def __init__(self, dependency: Callable = None, *, use_cache: bool = True): self.dependency = dependency + self.use_cache = use_cache class Security(Depends): - def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None): + def __init__( + self, + dependency: Callable = None, + *, + scopes: Sequence[str] = None, + use_cache: bool = True, + ): + super().__init__(dependency=dependency, use_cache=use_cache) self.scopes = scopes or [] - super().__init__(dependency=dependency) diff --git a/fastapi/routing.py b/fastapi/routing.py index 8526d8c04..4ae8bb586 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -102,7 +102,7 @@ def get_app( raise HTTPException( status_code=400, detail="There was an error parsing the body" ) from e - values, errors, background_tasks = await solve_dependencies( + values, errors, background_tasks, _ = await solve_dependencies( request=request, dependant=dependant, body=body, @@ -141,7 +141,7 @@ def get_websocket_app( dependant: Dependant, dependency_overrides_provider: Any = None ) -> Callable: async def app(websocket: WebSocket) -> None: - values, errors, _ = await solve_dependencies( + values, errors, _, _2 = await solve_dependencies( request=websocket, dependant=dependant, dependency_overrides_provider=dependency_overrides_provider, diff --git a/tests/test_dependency_cache.py b/tests/test_dependency_cache.py new file mode 100644 index 000000000..e9d027b1d --- /dev/null +++ b/tests/test_dependency_cache.py @@ -0,0 +1,68 @@ +from fastapi import Depends, FastAPI +from starlette.testclient import TestClient + +app = FastAPI() + +counter_holder = {"counter": 0} + + +async def dep_counter(): + counter_holder["counter"] += 1 + return counter_holder["counter"] + + +async def super_dep(count: int = Depends(dep_counter)): + return count + + +@app.get("/counter/") +async def get_counter(count: int = Depends(dep_counter)): + return {"counter": count} + + +@app.get("/sub-counter/") +async def get_sub_counter( + subcount: int = Depends(super_dep), count: int = Depends(dep_counter) +): + return {"counter": count, "subcounter": subcount} + + +@app.get("/sub-counter-no-cache/") +async def get_sub_counter_no_cache( + subcount: int = Depends(super_dep), + count: int = Depends(dep_counter, use_cache=False), +): + return {"counter": count, "subcounter": subcount} + + +client = TestClient(app) + + +def test_normal_counter(): + counter_holder["counter"] = 0 + response = client.get("/counter/") + assert response.status_code == 200 + assert response.json() == {"counter": 1} + response = client.get("/counter/") + assert response.status_code == 200 + assert response.json() == {"counter": 2} + + +def test_sub_counter(): + counter_holder["counter"] = 0 + response = client.get("/sub-counter/") + assert response.status_code == 200 + assert response.json() == {"counter": 1, "subcounter": 1} + response = client.get("/sub-counter/") + assert response.status_code == 200 + assert response.json() == {"counter": 2, "subcounter": 2} + + +def test_sub_counter_no_cache(): + counter_holder["counter"] = 0 + response = client.get("/sub-counter-no-cache/") + assert response.status_code == 200 + assert response.json() == {"counter": 2, "subcounter": 1} + response = client.get("/sub-counter-no-cache/") + assert response.status_code == 200 + assert response.json() == {"counter": 4, "subcounter": 3}