From cca673514c33fc7bc69af42570c1b64b1d540dd4 Mon Sep 17 00:00:00 2001 From: Peter Volf Date: Tue, 8 Oct 2024 11:08:14 +0200 Subject: [PATCH] fix duplicate websocket dependency test --- tests/test_duplicate_special_dependencies.py | 21 ++++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_duplicate_special_dependencies.py b/tests/test_duplicate_special_dependencies.py index 82944da1f..8dfb57cb2 100644 --- a/tests/test_duplicate_special_dependencies.py +++ b/tests/test_duplicate_special_dependencies.py @@ -27,13 +27,6 @@ def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str: return "success" -@app.get("/websocket") -def websocket(ws1: WebSocket, ws2: WebSocket) -> str: - assert ws1 is not None - assert ws1 is ws2 - return "success" - - @app.get("/security-scopes") def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: assert sc1 is not None @@ -41,7 +34,12 @@ def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: return "success" -client = TestClient(app) +@app.websocket("/websocket") +async def websocket(ws1: WebSocket, ws2: WebSocket) -> str: + assert ws1 is ws2 + await ws1.accept() + await ws1.send_text("success") + await ws1.close() @pytest.mark.parametrize( @@ -54,9 +52,10 @@ client = TestClient(app) ), ) def test_duplicate_special_dependency(url: str) -> None: - assert client.get(url).text == '"success"' + assert TestClient(app).get(url).text == '"success"' def test_duplicate_websocket_dependency() -> None: - # Raises exception if connect fails. - client.websocket_connect("/websocket") + with TestClient(app).websocket_connect("/websocket") as ws: + text = ws.receive_text() + assert text == "success"