diff --git a/fastapi/routing.py b/fastapi/routing.py index 6d252d817..67619bda5 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -271,6 +271,10 @@ class APIRouter(routing.Router): include_in_schema=route.include_in_schema, name=route.name, ) + elif isinstance(route, routing.WebSocketRoute): + self.add_websocket_route( + prefix + route.path, route.endpoint, name=route.name + ) def get( self, diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py new file mode 100644 index 000000000..d3c69ca1f --- /dev/null +++ b/tests/test_ws_router.py @@ -0,0 +1,53 @@ +from fastapi import APIRouter, FastAPI +from starlette.testclient import TestClient +from starlette.websockets import WebSocket + +router = APIRouter() +prefix_router = APIRouter() +app = FastAPI() + + +@app.websocket_route("/") +async def index(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Hello, world!") + await websocket.close() + + +@router.websocket_route("/router") +async def routerindex(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Hello, router!") + await websocket.close() + + +@prefix_router.websocket_route("/") +async def routerprefixindex(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Hello, router with prefix!") + await websocket.close() + + +app.include_router(router) +app.include_router(prefix_router, prefix="/prefix") + + +def test_app(): + client = TestClient(app) + with client.websocket_connect("/") as websocket: + data = websocket.receive_text() + assert data == "Hello, world!" + + +def test_router(): + client = TestClient(app) + with client.websocket_connect("/router") as websocket: + data = websocket.receive_text() + assert data == "Hello, router!" + + +def test_prefix_router(): + client = TestClient(app) + with client.websocket_connect("/prefix/") as websocket: + data = websocket.receive_text() + assert data == "Hello, router with prefix!"