From 8155667c780384d4270238ce59c36c9ce4dde868 Mon Sep 17 00:00:00 2001 From: Peter Volf Date: Tue, 8 Oct 2024 10:19:10 +0200 Subject: [PATCH 1/3] fix handling of duplicate "special" dependencies --- fastapi/dependencies/models.py | 12 +++++----- fastapi/dependencies/utils.py | 43 ++++++++++++++++++---------------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..62fd4cc89 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -22,12 +22,12 @@ class Dependant: security_requirements: List[SecurityRequirement] = field(default_factory=list) name: Optional[str] = None call: Optional[Callable[..., Any]] = None - request_param_name: Optional[str] = None - websocket_param_name: Optional[str] = None - http_connection_param_name: Optional[str] = None - response_param_name: Optional[str] = None - background_tasks_param_name: Optional[str] = None - security_scopes_param_name: Optional[str] = None + request_param_names: List[str] = field(default_factory=list) + websocket_param_names: List[str] = field(default_factory=list) + http_connection_param_names: List[str] = field(default_factory=list) + response_param_names: List[str] = field(default_factory=list) + background_tasks_param_names: List[str] = field(default_factory=list) + security_scopes_param_names: List[str] = field(default_factory=list) security_scopes: Optional[List[str]] = None use_cache: bool = True path: Optional[str] = None diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 5cebbf00f..2ff4ebf3e 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -310,22 +310,22 @@ def add_non_field_param_to_dependency( *, param_name: str, type_annotation: Any, dependant: Dependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): - dependant.request_param_name = param_name + dependant.request_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, WebSocket): - dependant.websocket_param_name = param_name + dependant.websocket_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, HTTPConnection): - dependant.http_connection_param_name = param_name + dependant.http_connection_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, Response): - dependant.response_param_name = param_name + dependant.response_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, StarletteBackgroundTasks): - dependant.background_tasks_param_name = param_name + dependant.background_tasks_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, SecurityScopes): - dependant.security_scopes_param_name = param_name + dependant.security_scopes_param_names.append(param_name) return True return None @@ -660,22 +660,25 @@ async def solve_dependencies( ) values.update(body_values) errors.extend(body_errors) - if dependant.http_connection_param_name: - values[dependant.http_connection_param_name] = request - if dependant.request_param_name and isinstance(request, Request): - values[dependant.request_param_name] = request - elif dependant.websocket_param_name and isinstance(request, WebSocket): - values[dependant.websocket_param_name] = request - if dependant.background_tasks_param_name: + for name in dependant.http_connection_param_names: + values[name] = request + if isinstance(request, Request): + for name in dependant.request_param_names: + values[name] = request + elif isinstance(request, WebSocket): + for name in dependant.websocket_param_names: + values[name] = request + if dependant.background_tasks_param_names: if background_tasks is None: background_tasks = BackgroundTasks() - values[dependant.background_tasks_param_name] = background_tasks - if dependant.response_param_name: - values[dependant.response_param_name] = response - if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.security_scopes - ) + for name in dependant.background_tasks_param_names: + values[name] = background_tasks + for name in dependant.response_param_names: + values[name] = response + if dependant.security_scopes_param_names: + security_scope = SecurityScopes(scopes=dependant.security_scopes) + for name in dependant.security_scopes_param_names: + values[name] = security_scope return SolvedDependency( values=values, errors=errors, From bda286314a911f5144f4f0f67d27b347506b8c47 Mon Sep 17 00:00:00 2001 From: Peter Volf Date: Tue, 8 Oct 2024 10:39:26 +0200 Subject: [PATCH 2/3] add tests for duplicate special dependency bug fix --- tests/test_duplicate_special_dependencies.py | 62 ++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/test_duplicate_special_dependencies.py diff --git a/tests/test_duplicate_special_dependencies.py b/tests/test_duplicate_special_dependencies.py new file mode 100644 index 000000000..82944da1f --- /dev/null +++ b/tests/test_duplicate_special_dependencies.py @@ -0,0 +1,62 @@ +import pytest +from fastapi import BackgroundTasks, FastAPI, Request, Response, WebSocket +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.get("/request") +def request(r1: Request, r2: Request) -> str: + assert r1 is not None + assert r1 is r2 + return "success" + + +@app.get("/response") +def response(r1: Response, r2: Response) -> str: + assert r1 is not None + assert r1 is r2 + return "success" + + +@app.get("/background-tasks") +def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str: + assert t1 is not None + assert t1 is t2 + return "success" + + +@app.get("/websocket") +def websocket(ws1: WebSocket, ws2: WebSocket) -> str: + assert ws1 is not None + assert ws1 is ws2 + return "success" + + +@app.get("/security-scopes") +def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: + assert sc1 is not None + assert sc1 is sc2 + return "success" + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "url", + ( + "/request", + "/response", + "/background-tasks", + "/security-scopes", + ), +) +def test_duplicate_special_dependency(url: str) -> None: + assert client.get(url).text == '"success"' + + +def test_duplicate_websocket_dependency() -> None: + # Raises exception if connect fails. + client.websocket_connect("/websocket") From cca673514c33fc7bc69af42570c1b64b1d540dd4 Mon Sep 17 00:00:00 2001 From: Peter Volf Date: Tue, 8 Oct 2024 11:08:14 +0200 Subject: [PATCH 3/3] fix duplicate websocket dependency test --- tests/test_duplicate_special_dependencies.py | 21 ++++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_duplicate_special_dependencies.py b/tests/test_duplicate_special_dependencies.py index 82944da1f..8dfb57cb2 100644 --- a/tests/test_duplicate_special_dependencies.py +++ b/tests/test_duplicate_special_dependencies.py @@ -27,13 +27,6 @@ def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str: return "success" -@app.get("/websocket") -def websocket(ws1: WebSocket, ws2: WebSocket) -> str: - assert ws1 is not None - assert ws1 is ws2 - return "success" - - @app.get("/security-scopes") def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: assert sc1 is not None @@ -41,7 +34,12 @@ def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: return "success" -client = TestClient(app) +@app.websocket("/websocket") +async def websocket(ws1: WebSocket, ws2: WebSocket) -> str: + assert ws1 is ws2 + await ws1.accept() + await ws1.send_text("success") + await ws1.close() @pytest.mark.parametrize( @@ -54,9 +52,10 @@ client = TestClient(app) ), ) def test_duplicate_special_dependency(url: str) -> None: - assert client.get(url).text == '"success"' + assert TestClient(app).get(url).text == '"success"' def test_duplicate_websocket_dependency() -> None: - # Raises exception if connect fails. - client.websocket_connect("/websocket") + with TestClient(app).websocket_connect("/websocket") as ws: + text = ws.receive_text() + assert text == "success"