From 02441ff0313d5b471b662293244c53e712f1243f Mon Sep 17 00:00:00 2001 From: amitlissack Date: Mon, 30 Mar 2020 14:45:05 -0400 Subject: [PATCH] :bug: Fix dependency overrides in WebSockets (#1122) * add tests to test_ws_router to test dependencies and dependency overrides. * supply dependency_overrides_provider to APIWebSocketRoute upon creation --- fastapi/routing.py | 7 ++++++- tests/test_ws_router.py | 28 +++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index b90935e15..1ec0b693c 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -498,7 +498,12 @@ class APIRouter(routing.Router): def add_api_websocket_route( self, path: str, endpoint: Callable, name: str = None ) -> None: - route = APIWebSocketRoute(path, endpoint=endpoint, name=name) + route = APIWebSocketRoute( + path, + endpoint=endpoint, + name=name, + dependency_overrides_provider=self.dependency_overrides_provider, + ) self.routes.append(route) def websocket(self, path: str, name: str = None) -> Callable: diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index fd19e650a..dd0456127 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, FastAPI, WebSocket +from fastapi import APIRouter, Depends, FastAPI, WebSocket from fastapi.testclient import TestClient router = APIRouter() @@ -34,6 +34,19 @@ async def routerindex(websocket: WebSocket): await websocket.close() +async def ws_dependency(): + return "Socket Dependency" + + +@router.websocket("/router-ws-depends/") +async def router_ws_decorator_depends( + websocket: WebSocket, data=Depends(ws_dependency) +): + await websocket.accept() + await websocket.send_text(data) + await websocket.close() + + app.include_router(router) app.include_router(prefix_router, prefix="/prefix") @@ -64,3 +77,16 @@ def test_router2(): with client.websocket_connect("/router2") as websocket: data = websocket.receive_text() assert data == "Hello, router!" + + +def test_router_ws_depends(): + client = TestClient(app) + with client.websocket_connect("/router-ws-depends/") as websocket: + assert websocket.receive_text() == "Socket Dependency" + + +def test_router_ws_depends_with_override(): + client = TestClient(app) + app.dependency_overrides[ws_dependency] = lambda: "Override" + with client.websocket_connect("/router-ws-depends/") as websocket: + assert websocket.receive_text() == "Override"