From d8b6aa630cc9429fb24fcd0b6e3e626c7ea4e09f Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 1 Sep 2022 10:50:47 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20support=20for=20path=20par?= =?UTF-8?q?ameters=20in=20WebSockets=20(#3879)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez --- fastapi/routing.py | 4 ++-- tests/test_ws_router.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 1ac4b3880..233f79fcb 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -297,14 +297,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name - self.dependant = get_dependant(path=path, call=self.endpoint) + self.path_regex, self.path_format, self.param_convertors = compile_path(path) + self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.app = websocket_session( get_websocket_app( dependant=self.dependant, dependency_overrides_provider=dependency_overrides_provider, ) ) - self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> Tuple[Match, Scope]: match, child_scope = super().matches(scope) diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index fbca104a2..206d743ba 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -35,6 +35,14 @@ async def routerindex2(websocket: WebSocket): await websocket.close() +@router.websocket("/router/{pathparam:path}") +async def routerindexparams(websocket: WebSocket, pathparam: str, queryparam: str): + await websocket.accept() + await websocket.send_text(pathparam) + await websocket.send_text(queryparam) + await websocket.close() + + async def ws_dependency(): return "Socket Dependency" @@ -106,3 +114,14 @@ def test_router_ws_depends_with_override(): app.dependency_overrides[ws_dependency] = lambda: "Override" with client.websocket_connect("/router-ws-depends/") as websocket: assert websocket.receive_text() == "Override" + + +def test_router_with_params(): + client = TestClient(app) + with client.websocket_connect( + "/router/path/to/file?queryparam=a_query_param" + ) as websocket: + data = websocket.receive_text() + assert data == "path/to/file" + data = websocket.receive_text() + assert data == "a_query_param"