diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..62fd4cc89 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -22,12 +22,12 @@ class Dependant: security_requirements: List[SecurityRequirement] = field(default_factory=list) name: Optional[str] = None call: Optional[Callable[..., Any]] = None - request_param_name: Optional[str] = None - websocket_param_name: Optional[str] = None - http_connection_param_name: Optional[str] = None - response_param_name: Optional[str] = None - background_tasks_param_name: Optional[str] = None - security_scopes_param_name: Optional[str] = None + request_param_names: List[str] = field(default_factory=list) + websocket_param_names: List[str] = field(default_factory=list) + http_connection_param_names: List[str] = field(default_factory=list) + response_param_names: List[str] = field(default_factory=list) + background_tasks_param_names: List[str] = field(default_factory=list) + security_scopes_param_names: List[str] = field(default_factory=list) security_scopes: Optional[List[str]] = None use_cache: bool = True path: Optional[str] = None diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..852ecced6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -318,22 +318,22 @@ def add_non_field_param_to_dependency( *, param_name: str, type_annotation: Any, dependant: Dependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): - dependant.request_param_name = param_name + dependant.request_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, WebSocket): - dependant.websocket_param_name = param_name + dependant.websocket_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, HTTPConnection): - dependant.http_connection_param_name = param_name + dependant.http_connection_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, Response): - dependant.response_param_name = param_name + dependant.response_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, StarletteBackgroundTasks): - dependant.background_tasks_param_name = param_name + dependant.background_tasks_param_names.append(param_name) return True elif lenient_issubclass(type_annotation, SecurityScopes): - dependant.security_scopes_param_name = param_name + dependant.security_scopes_param_names.append(param_name) return True return None @@ -670,22 +670,25 @@ async def solve_dependencies( ) values.update(body_values) errors.extend(body_errors) - if dependant.http_connection_param_name: - values[dependant.http_connection_param_name] = request - if dependant.request_param_name and isinstance(request, Request): - values[dependant.request_param_name] = request - elif dependant.websocket_param_name and isinstance(request, WebSocket): - values[dependant.websocket_param_name] = request - if dependant.background_tasks_param_name: + for name in dependant.http_connection_param_names: + values[name] = request + if isinstance(request, Request): + for name in dependant.request_param_names: + values[name] = request + elif isinstance(request, WebSocket): + for name in dependant.websocket_param_names: + values[name] = request + if dependant.background_tasks_param_names: if background_tasks is None: background_tasks = BackgroundTasks() - values[dependant.background_tasks_param_name] = background_tasks - if dependant.response_param_name: - values[dependant.response_param_name] = response - if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.security_scopes - ) + for name in dependant.background_tasks_param_names: + values[name] = background_tasks + for name in dependant.response_param_names: + values[name] = response + if dependant.security_scopes_param_names: + security_scope = SecurityScopes(scopes=dependant.security_scopes) + for name in dependant.security_scopes_param_names: + values[name] = security_scope return SolvedDependency( values=values, errors=errors, diff --git a/tests/test_duplicate_special_dependencies.py b/tests/test_duplicate_special_dependencies.py new file mode 100644 index 000000000..8dfb57cb2 --- /dev/null +++ b/tests/test_duplicate_special_dependencies.py @@ -0,0 +1,61 @@ +import pytest +from fastapi import BackgroundTasks, FastAPI, Request, Response, WebSocket +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.get("/request") +def request(r1: Request, r2: Request) -> str: + assert r1 is not None + assert r1 is r2 + return "success" + + +@app.get("/response") +def response(r1: Response, r2: Response) -> str: + assert r1 is not None + assert r1 is r2 + return "success" + + +@app.get("/background-tasks") +def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str: + assert t1 is not None + assert t1 is t2 + return "success" + + +@app.get("/security-scopes") +def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: + assert sc1 is not None + assert sc1 is sc2 + return "success" + + +@app.websocket("/websocket") +async def websocket(ws1: WebSocket, ws2: WebSocket) -> str: + assert ws1 is ws2 + await ws1.accept() + await ws1.send_text("success") + await ws1.close() + + +@pytest.mark.parametrize( + "url", + ( + "/request", + "/response", + "/background-tasks", + "/security-scopes", + ), +) +def test_duplicate_special_dependency(url: str) -> None: + assert TestClient(app).get(url).text == '"success"' + + +def test_duplicate_websocket_dependency() -> None: + with TestClient(app).websocket_connect("/websocket") as ws: + text = ws.receive_text() + assert text == "success"