From b993b4af287e63904b675ba2ee1c232e86f2b072 Mon Sep 17 00:00:00 2001 From: laggardkernel Date: Tue, 23 Aug 2022 21:30:24 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20cached=20dependencies=20wh?= =?UTF-8?q?en=20using=20a=20dependency=20in=20`Security()`=20and=20other?= =?UTF-8?q?=20places=20(e.g.=20`Depends()`)=20with=20different=20OAuth2=20?= =?UTF-8?q?scopes=20(#2945)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastián Ramírez --- fastapi/dependencies/utils.py | 10 +++++++--- tests/test_dependency_cache.py | 25 ++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index f397e333c..f6151f6bd 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -161,7 +161,6 @@ def get_sub_dependant( ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) - sub_dependant.security_scopes = security_scopes return sub_dependant @@ -278,7 +277,13 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache) + dependant = Dependant( + call=call, + name=name, + path=path, + security_scopes=security_scopes, + use_cache=use_cache, + ) for param_name, param in signature_params.items(): if isinstance(param.default, params.Depends): sub_dependant = get_param_sub_dependant( @@ -495,7 +500,6 @@ async def solve_dependencies( name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, ) - use_sub_dependant.security_scopes = sub_dependant.security_scopes solved_result = await solve_dependencies( request=request, diff --git a/tests/test_dependency_cache.py b/tests/test_dependency_cache.py index 65ed7f946..08fb9b74f 100644 --- a/tests/test_dependency_cache.py +++ b/tests/test_dependency_cache.py @@ -1,4 +1,4 @@ -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, Security from fastapi.testclient import TestClient app = FastAPI() @@ -35,6 +35,19 @@ async def get_sub_counter_no_cache( return {"counter": count, "subcounter": subcount} +@app.get("/scope-counter") +async def get_scope_counter( + count: int = Security(dep_counter), + scope_count_1: int = Security(dep_counter, scopes=["scope"]), + scope_count_2: int = Security(dep_counter, scopes=["scope"]), +): + return { + "counter": count, + "scope_counter_1": scope_count_1, + "scope_counter_2": scope_count_2, + } + + client = TestClient(app) @@ -66,3 +79,13 @@ def test_sub_counter_no_cache(): response = client.get("/sub-counter-no-cache/") assert response.status_code == 200, response.text assert response.json() == {"counter": 4, "subcounter": 3} + + +def test_security_cache(): + counter_holder["counter"] = 0 + response = client.get("/scope-counter/") + assert response.status_code == 200, response.text + assert response.json() == {"counter": 1, "scope_counter_1": 2, "scope_counter_2": 2} + response = client.get("/scope-counter/") + assert response.status_code == 200, response.text + assert response.json() == {"counter": 3, "scope_counter_1": 4, "scope_counter_2": 4}