From 4fb5d22a737d0b36a197673129b83a5d2b87605c Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 8 Sep 2024 12:45:27 +0200 Subject: [PATCH] Support and test app.websocket_route --- fastapi/routing.py | 10 ++++++++++ tests/test_ignore_trailing_slash.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/fastapi/routing.py b/fastapi/routing.py index b06704574..33e3b7434 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1040,6 +1040,16 @@ class APIRouter(routing.Router): return decorator + def add_websocket_route( + self, + path: str, + endpoint: Callable[..., Any], + name: Optional[str] = None, + ) -> None: + if self.ignore_trailing_slash: + path = path.rstrip("/") + super().add_websocket_route(path, endpoint, name) + def add_api_websocket_route( self, path: str, diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index be35993ff..61ef7b496 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -29,6 +29,20 @@ async def websocket_endpoint_with_slash(websocket: WebSocket): await websocket.close() +@app.websocket_route("/websocket_route") +async def websocket_route_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route") + await websocket.close() + + +@app.websocket_route("/websocket_route_2/") +async def websocket_route_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route 2") + await websocket.close() + + @router.get("/example") def route_endpoint(): return {"msg": "Routing Example"} @@ -58,6 +72,10 @@ def test_ignoring_trailing_shlash_ws(): assert websocket.receive_text() == "Websocket" with client.websocket_connect("/websocket2") as websocket: assert websocket.receive_text() == "Websocket 2" + with client.websocket_connect("/websocket_route/") as websocket: + assert websocket.receive_text() == "Websocket route" + with client.websocket_connect("/websocket_route_2/") as websocket: + assert websocket.receive_text() == "Websocket route 2" def test_ignoring_trailing_routing():