diff --git a/tests/test_duplicate_special_dependencies.py b/tests/test_duplicate_special_dependencies.py index 82944da1f..8dfb57cb2 100644 --- a/tests/test_duplicate_special_dependencies.py +++ b/tests/test_duplicate_special_dependencies.py @@ -27,13 +27,6 @@ def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str: return "success" -@app.get("/websocket") -def websocket(ws1: WebSocket, ws2: WebSocket) -> str: - assert ws1 is not None - assert ws1 is ws2 - return "success" - - @app.get("/security-scopes") def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: assert sc1 is not None @@ -41,7 +34,12 @@ def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: return "success" -client = TestClient(app) +@app.websocket("/websocket") +async def websocket(ws1: WebSocket, ws2: WebSocket) -> str: + assert ws1 is ws2 + await ws1.accept() + await ws1.send_text("success") + await ws1.close() @pytest.mark.parametrize( @@ -54,9 +52,10 @@ client = TestClient(app) ), ) def test_duplicate_special_dependency(url: str) -> None: - assert client.get(url).text == '"success"' + assert TestClient(app).get(url).text == '"success"' def test_duplicate_websocket_dependency() -> None: - # Raises exception if connect fails. - client.websocket_connect("/websocket") + with TestClient(app).websocket_connect("/websocket") as ws: + text = ws.receive_text() + assert text == "success"